-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update README.md and add notebook to quantize models
- Loading branch information
1 parent
6a25e10
commit 5a003a1
Showing
2 changed files
with
357 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,338 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6c0ca73e", | ||
"metadata": {}, | ||
"source": [ | ||
"This notebook contains hacks which allows to quantize and prepare models for mobile and test how fast they are working,\n", | ||
"but for the real use some extra work is needed.\n", | ||
"\n", | ||
"NeMo quantization cannot be used, because it exports models to TensorRT which does not work on android." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bb2afaee", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from argparse import ArgumentParser\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"import torch\n", | ||
"\n", | ||
"from nemo.collections.asr.models import EncDecCTCModel\n", | ||
"from nemo.collections.asr.models import EncDecCTCModelBPE\n", | ||
"from nemo.utils import logging\n", | ||
"from nemo.collections.asr.metrics.wer import WER, word_error_rate\n", | ||
"\n", | ||
"from nemo.collections.asr.parts.submodules.jasper import MaskedConv1d, _masked_conv_init_lens\n", | ||
"import torch.nn as nn\n", | ||
"\n", | ||
"import torch.quantization.quantize_fx as quantize_fx\n", | ||
"\n", | ||
"from nemo.core import typecheck\n", | ||
"typecheck.set_typecheck_enabled(False)\n", | ||
"\n", | ||
"#import warnings\n", | ||
"#warnings.simplefilter(\"ignore\", UserWarning)\n", | ||
"\n", | ||
"from nemo import __version__ as nemo_version" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6f07f316", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"c = '../nemo_experiments/Citrinet-256-8x-Stride-ru/2022-02-14_00-06-13/checkpoints/Citrinet-256-8x-Stride-ru.nemo'\n", | ||
"data = '/media/storage/yashkin/golos/train/1hour.jsonl'\n", | ||
"\n", | ||
"test_ds_l = ['/media/storage/yashkin/golos/test/crowd/test_crowd.jsonl', '/media/storage/yashkin/golos/test/farfield/test_farfield.jsonl']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5f31e6e5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"nemo_version" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6eb7d7ac", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class Minimal(nn.Module):\n", | ||
" def __init__(self, preprocessor, encoder, decoder):\n", | ||
" super().__init__()\n", | ||
" self.preprocessor = preprocessor\n", | ||
" self.encoder = encoder\n", | ||
" self.decoder = decoder\n", | ||
"\n", | ||
" def forward(self):\n", | ||
" input_signal, input_signal_length = torch.zeros(1, 16000), torch.Tensor([16000])\n", | ||
" pre = self.preprocessor(input_signal, input_signal_length)\n", | ||
" enc = self.encoder(*pre)\n", | ||
" dec = self.decoder(enc[0])\n", | ||
" return dec\n", | ||
"\n", | ||
"\n", | ||
"can_gpu = False # torch.cuda.is_available()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c3dfd72d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_forward(mask):\n", | ||
" def forward_t(x, lens):\n", | ||
" self = mask\n", | ||
" #if self.use_mask:\n", | ||
" # x, lens = self.update_masked_length(x, lens)\n", | ||
"\n", | ||
" # asymmtric pad if necessary\n", | ||
" if self.pad_layer is not None:\n", | ||
" x = self.pad_layer(x)\n", | ||
"\n", | ||
" sh = x.shape\n", | ||
" if self.heads != -1:\n", | ||
" x = x.view(-1, self.heads, sh[-1])\n", | ||
"\n", | ||
" out = self.conv(x)\n", | ||
"\n", | ||
" if self.heads != -1:\n", | ||
" out = out.view(sh[0], self.real_out_channels, -1)\n", | ||
"\n", | ||
" return out, lens\n", | ||
" return forward_t\n", | ||
" \n", | ||
"\n", | ||
"\n", | ||
"def make_traceable(model):\n", | ||
" assert nemo_version == '1.5.0', 'forward_t will not work with other NeMo versions'\n", | ||
" \n", | ||
" for name, module in model.named_children():\n", | ||
" if len(list(module.children())) > 0:\n", | ||
" ## compound module, go inside it\n", | ||
" make_traceable(module)\n", | ||
"\n", | ||
" if isinstance(module, MaskedConv1d):\n", | ||
" module.use_mask = False\n", | ||
" module.forward = get_forward(module)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1a36226d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# asr_model = EncDecCTCModelBPE.restore_from(c)\n", | ||
"asr_model = EncDecCTCModel.from_pretrained(\"QuartzNet15x5Base-En\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c3e3c60e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def setup_data(asr_model, vocabulary, data):\n", | ||
" asr_model.preprocessor.featurizer.pad_to = 0\n", | ||
" asr_model.preprocessor.featurizer.dither = 0.0\n", | ||
" \n", | ||
" asr_model.setup_test_data(\n", | ||
" test_data_config={\n", | ||
" 'sample_rate': 16000,\n", | ||
" 'manifest_filepath': data,\n", | ||
" 'labels': vocabulary,\n", | ||
" 'batch_size': 64,\n", | ||
" 'normalize_transcripts': False,\n", | ||
" 'num_workers': 4,\n", | ||
" }\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "59982a9f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"vocabulary = asr_model.decoder.vocabulary\n", | ||
"setup_data(asr_model, vocabulary, data)\n", | ||
"\n", | ||
"if can_gpu:\n", | ||
" asr_model = asr_model.cuda()\n", | ||
"\n", | ||
"make_traceable(asr_model)\n", | ||
"asr_model.eval();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f5e8d6ec", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"minimal = Minimal(asr_model.preprocessor, asr_model.encoder, asr_model.decoder)\n", | ||
"minimal.cpu()\n", | ||
"# FIXME !!!\n", | ||
"trace = torch.jit.trace(minimal, (), check_trace=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "949fa2db", | ||
"metadata": {}, | ||
"source": [ | ||
"Save float version" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a5818943", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#trace._save_for_lite_interpreter(\"trace_q_f.ptl\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aedc5080", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"qconfig_dict = qconfig_dict = {\"\": torch.quantization.get_default_qconfig('qnnpack')}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2d667524", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"asr_model.encoder = quantize_fx.prepare_fx(asr_model.encoder, qconfig_dict)\n", | ||
"asr_model.decoder = quantize_fx.prepare_fx(asr_model.decoder, qconfig_dict)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fcadf17c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for test_batch in asr_model.test_dataloader():\n", | ||
" if can_gpu:\n", | ||
" test_batch = [x.cuda() for x in test_batch]\n", | ||
" log_probs, encoded_len, greedy_predictions = asr_model(\n", | ||
" input_signal=test_batch[0], input_signal_length=test_batch[1]\n", | ||
" )\n", | ||
" del test_batch\n", | ||
" break" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0f41021e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"asr_model.encoder = quantize_fx.convert_fx(asr_model.encoder)\n", | ||
"asr_model.decoder = quantize_fx.convert_fx(asr_model.decoder)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0a370aab", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"asr_model.cpu();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5ae2f4b2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"minimal = Minimal(asr_model.preprocessor, asr_model.encoder, asr_model.decoder)\n", | ||
"# FIXME !!!\n", | ||
"trace = torch.jit.trace(minimal, (), check_trace=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "edd8a859", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trace._save_for_lite_interpreter(\"trace_q_i.ptl\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "91f3d1fe", | ||
"metadata": {}, | ||
"source": [ | ||
"Save int8 version" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "136043c8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |