Fix exporting decoder model to onnx (#1264)
* Use torch.jit.script() to export the decoder model

See also k2-fsa/sherpa-onnx#327
csukuangfj authored Sep 22, 2023
1 parent f5dc957 commit 34e40a8
Showing 17 changed files with 17 additions and 0 deletions.
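
Every hunk below makes the same one-line change: the decoder is wrapped with torch.jit.script() before torch.onnx.export() is called, so data-dependent control flow in the decoder's forward() is preserved in the exported graph instead of being fixed by tracing. The following is a minimal, self-contained sketch of that pattern; the TinyDecoder class, the output path, and the opset_version are illustrative assumptions only and are not part of this diff, whose context lines show the real call sites.

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    # Hypothetical stand-in for the real decoder, used only to make this
    # sketch self-contained; the actual export functions receive the
    # recipe's decoder wrapper instead.
    def __init__(self, vocab_size: int = 500, context_size: int = 2, dim: int = 64):
        super().__init__()
        self.context_size = context_size
        self.embedding = nn.Embedding(vocab_size, dim)
        self.out = nn.Linear(dim * context_size, vocab_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) token ids -> (N, vocab_size) logits
        emb = self.embedding(y).reshape(y.size(0), -1)
        return self.out(emb)


decoder_model = TinyDecoder()
context_size = decoder_model.context_size

y = torch.zeros(10, context_size, dtype=torch.int64)

# The fix added in every file below: script the model instead of exporting
# the eager nn.Module directly, so control flow inside forward() is kept
# in the ONNX graph rather than baked in by tracing.
decoder_model = torch.jit.script(decoder_model)

torch.onnx.export(
    decoder_model,
    y,
    "decoder.onnx",      # illustrative output path
    opset_version=13,    # assumed opset; the diff does not change it
    input_names=["y"],
    output_names=["decoder_out"],
)
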
@@ -322,6 +322,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -330,6 +330,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -401,6 +401,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -359,6 +359,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -356,6 +356,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -307,6 +307,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -312,6 +312,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -404,6 +404,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -335,6 +335,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -329,6 +329,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -413,6 +413,7 @@ def export_decoder_model_onnx(
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -401,6 +401,7 @@ def export_decoder_model_onnx(
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

egs/librispeech/ASR/zipformer/export-onnx-streaming.py (1 addition, 0 deletions)
@@ -506,6 +506,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

egs/librispeech/ASR/zipformer/export-onnx.py (1 addition, 0 deletions)
@@ -353,6 +353,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -315,6 +315,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -404,6 +404,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -335,6 +335,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size
 
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,
