Add onnx export support for pruned_transducer_stateless5 (#883)
csukuangfj authored Feb 7, 2023
1 parent ffbf6d9 commit 7ae03f6
Showing 6 changed files with 663 additions and 11 deletions.
72 changes: 72 additions & 0 deletions .github/scripts/test-onnx-export.sh
@@ -120,3 +120,75 @@ log "Run onnx_pretrained.py"
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

rm -rf $repo
log "--------------------------------------------------------------------------"


log "=========================================================================="
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt"

cd exp
ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt
popd

log "Export via torch.jit.script()"

./pruned_transducer_stateless5/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--exp-dir $repo/exp \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512 \
--jit 1

log "Test exporting to ONNX format"

./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--exp-dir $repo/exp \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512

ls -lh $repo/exp

log "Run onnx_check.py"

./pruned_transducer_stateless5/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx

log "Run onnx_pretrained.py"

./pruned_transducer_stateless5/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

rm -rf $repo
log "--------------------------------------------------------------------------"
egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py
@@ -6,7 +6,7 @@
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
-https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
as an example to show how to use this file.
1. Download the pre-trained model
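The export in this script ultimately reduces to torch.onnx.export calls with dynamic batch/time axes, one call per sub-model (encoder, decoder, joiner). Below is a minimal, self-contained sketch of the encoder pattern; DummyEncoder and all names are placeholders, while the real script wraps the actual Conformer encoder and attaches extra metadata.

# Minimal sketch of the torch.onnx.export pattern; names are placeholders.
import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    # Stand-in for the real encoder, mapping (N, T, 80) -> (N, T, 512),
    # just to make the export pattern runnable.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 512)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        return self.proj(x), x_lens

model = DummyEncoder()
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)

torch.onnx.export(
    model,
    (x, x_lens),
    "encoder.onnx",
    opset_version=13,
    input_names=["x", "x_lens"],
    output_names=["encoder_out", "encoder_out_lens"],
    dynamic_axes={
        "x": {0: "N", 1: "T"},
        "x_lens": {0: "N"},
        "encoder_out": {0: "N", 1: "T"},
        "encoder_out_lens": {0: "N"},
    },
)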
33 changes: 23 additions & 10 deletions egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -32,7 +32,7 @@
)
from torch import Tensor, nn

-from icefall.utils import make_pad_mask, subsequent_chunk_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask


class Conformer(EncoderInterface):
@@ -1012,15 +1012,28 @@ def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
            n == left_context + 2 * time1 - 1
        ), f"{n} == {left_context} + 2 * {time1} - 1"
-        # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time2),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
-        )

+        if is_jit_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time2)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+
+            x = x.reshape(-1, n)
+            x = torch.gather(x, dim=1, index=indexes)
+            x = x.reshape(batch_size, num_heads, time1, time2)
+            return x
+        else:
+            # Note: TorchScript requires explicit arg for stride()
+            batch_stride = x.stride(0)
+            head_stride = x.stride(1)
+            time1_stride = x.stride(2)
+            n_stride = x.stride(3)
+            return x.as_strided(
+                (batch_size, num_heads, time1, time2),
+                (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+                storage_offset=n_stride * (time1 - 1),
+            )

    def multi_head_attention_forward(
        self,
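The point of the new is_jit_tracing() branch: as_strided produces a strided view that does not map onto ONNX operators during tracing, whereas torch.gather exports cleanly, and the two formulations compute the same relative shift. The standalone Python self-check below re-implements both paths outside the class (with left_context = 0 for brevity) and asserts they agree; it is an illustration, not code from this commit.

# Self-check: gather-based rel_shift matches the as_strided-based one.
import torch

def rel_shift_strided(x: torch.Tensor) -> torch.Tensor:
    b, h, t1, n = x.shape  # n == 2 * t1 - 1 when left_context == 0
    return x.as_strided(
        (b, h, t1, t1),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (t1 - 1),
    )

def rel_shift_gather(x: torch.Tensor) -> torch.Tensor:
    b, h, t1, n = x.shape
    rows = torch.arange(start=t1 - 1, end=-1, step=-1)  # [t1-1, ..., 0]
    cols = torch.arange(t1)
    indexes = rows.repeat(b * h).unsqueeze(-1) + cols   # per-row gather offsets
    x = torch.gather(x.reshape(-1, n), dim=1, index=indexes)
    return x.reshape(b, h, t1, t1)

x = torch.randn(2, 4, 5, 9)  # (batch, heads, time1, 2*time1-1)
assert torch.equal(rel_shift_strided(x), rel_shift_gather(x))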
(Diffs for the remaining three changed files were not loaded in this view.)
