Update README.md and add notebook to quantize models
AndreyYashkin committed May 15, 2022
1 parent 6a25e10 commit 5a003a1
Showing 2 changed files with 357 additions and 3 deletions.
22 changes: 19 additions & 3 deletions README.md
@@ -24,7 +24,13 @@ If you are not planning to train, you can download Golos in opus format instead
```
python3 datasets/get_commonvoice_data.py --data_root data/mcv
```

Replace "ё" symbol with "е" as it is done in Golos
```
sed -i 's/\\u0451/\\u0435/g' data/mcv/commonvoice_dev_manifest.json
sed -i 's/\\u0451/\\u0435/g' data/mcv/commonvoice_test_manifest.json
sed -i 's/\\u0451/\\u0435/g' data/mcv/commonvoice_train_manifest.json
```
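
A quick optional check that nothing was missed — assuming the manifests store "ё" as the escaped sequence `\u0451`, each file should report a count of 0:
```
grep -c 'u0451' data/mcv/commonvoice_dev_manifest.json data/mcv/commonvoice_test_manifest.json data/mcv/commonvoice_train_manifest.json
```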

## Training
Create the word-piece tokenization
@@ -42,7 +48,8 @@ Check that it is possible to compute the CTC loss for most of the samples.
```
python3 ctc_loss_check.py --config-name=finetune_citrinet_256_eng
```
Fine-tune the pretrained English model after downloading "STT En Citrinet 256" from NVIDIA NGC and
putting it at `nemo_experiments/stt_en_citrinet_256.nemo`
```
python3 speech_to_text_finetune.py --config-name=finetune_citrinet_256_eng
```
@@ -52,7 +59,16 @@ python3 speech_to_text_ctc_bpe.py --config-path=conf --config-name=citrinet_256_
```

## Getting metrics

To compute metrics for the data in `$MANIFEST_PATH`:
```
python speech_to_text_eval.py model_path=nemo_experiments/Citrinet-256-8x-Stride-ru/.../checkpoints/Citrinet-256-8x-Stride-ru.nemo dataset_manifest="$MANIFEST_PATH"
```
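
For example, to score the Common Voice test manifest prepared above, point `$MANIFEST_PATH` at it first:
```
export MANIFEST_PATH=data/mcv/commonvoice_test_manifest.json
```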

## Check mobile performance (Android)
To convert the model to a format that can be used on mobile, see the notebook `mobile/prepare for mobile.ipynb`.
[Follow the PyTorch tutorial to measure Android performance](https://pytorch.org/tutorials/recipes/mobile_perf.html).
Make sure that your mobile torch build includes [fft support](https://github.com/pytorch/pytorch/commit/5045c18bd1214d8bdb3dc41306da9de06868aecb)
```
adb push trace_c_i.ptl /data/local/tmp
adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/trace_c_i.ptl --no_inputs true --iter 25"
```
338 changes: 338 additions & 0 deletions mobile/prepare for mobile.ipynb
@@ -0,0 +1,338 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6c0ca73e",
"metadata": {},
"source": [
"This notebook contains hacks which allows to quantize and prepare models for mobile and test how fast they are working,\n",
"but for the real use some extra work is needed.\n",
"\n",
"NeMo quantization cannot be used, because it exports models to TensorRT which does not work on android."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb2afaee",
"metadata": {},
"outputs": [],
"source": [
"from argparse import ArgumentParser\n",
"from tqdm import tqdm\n",
"\n",
"import torch\n",
"\n",
"from nemo.collections.asr.models import EncDecCTCModel\n",
"from nemo.collections.asr.models import EncDecCTCModelBPE\n",
"from nemo.utils import logging\n",
"from nemo.collections.asr.metrics.wer import WER, word_error_rate\n",
"\n",
"from nemo.collections.asr.parts.submodules.jasper import MaskedConv1d, _masked_conv_init_lens\n",
"import torch.nn as nn\n",
"\n",
"import torch.quantization.quantize_fx as quantize_fx\n",
"\n",
"from nemo.core import typecheck\n",
"typecheck.set_typecheck_enabled(False)\n",
"\n",
"#import warnings\n",
"#warnings.simplefilter(\"ignore\", UserWarning)\n",
"\n",
"from nemo import __version__ as nemo_version"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f07f316",
"metadata": {},
"outputs": [],
"source": [
"c = '../nemo_experiments/Citrinet-256-8x-Stride-ru/2022-02-14_00-06-13/checkpoints/Citrinet-256-8x-Stride-ru.nemo'\n",
"data = '/media/storage/yashkin/golos/train/1hour.jsonl'\n",
"\n",
"test_ds_l = ['/media/storage/yashkin/golos/test/crowd/test_crowd.jsonl', '/media/storage/yashkin/golos/test/farfield/test_farfield.jsonl']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f31e6e5",
"metadata": {},
"outputs": [],
"source": [
"nemo_version"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6eb7d7ac",
"metadata": {},
"outputs": [],
"source": [
"class Minimal(nn.Module):\n",
" def __init__(self, preprocessor, encoder, decoder):\n",
" super().__init__()\n",
" self.preprocessor = preprocessor\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
"\n",
" def forward(self):\n",
" input_signal, input_signal_length = torch.zeros(1, 16000), torch.Tensor([16000])\n",
" pre = self.preprocessor(input_signal, input_signal_length)\n",
" enc = self.encoder(*pre)\n",
" dec = self.decoder(enc[0])\n",
" return dec\n",
"\n",
"\n",
"can_gpu = False # torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3dfd72d",
"metadata": {},
"outputs": [],
"source": [
"def get_forward(mask):\n",
" def forward_t(x, lens):\n",
" self = mask\n",
" #if self.use_mask:\n",
" # x, lens = self.update_masked_length(x, lens)\n",
"\n",
" # asymmtric pad if necessary\n",
" if self.pad_layer is not None:\n",
" x = self.pad_layer(x)\n",
"\n",
" sh = x.shape\n",
" if self.heads != -1:\n",
" x = x.view(-1, self.heads, sh[-1])\n",
"\n",
" out = self.conv(x)\n",
"\n",
" if self.heads != -1:\n",
" out = out.view(sh[0], self.real_out_channels, -1)\n",
"\n",
" return out, lens\n",
" return forward_t\n",
" \n",
"\n",
"\n",
"def make_traceable(model):\n",
" assert nemo_version == '1.5.0', 'forward_t will not work with other NeMo versions'\n",
" \n",
" for name, module in model.named_children():\n",
" if len(list(module.children())) > 0:\n",
" ## compound module, go inside it\n",
" make_traceable(module)\n",
"\n",
" if isinstance(module, MaskedConv1d):\n",
" module.use_mask = False\n",
" module.forward = get_forward(module)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a36226d",
"metadata": {},
"outputs": [],
"source": [
"# asr_model = EncDecCTCModelBPE.restore_from(c)\n",
"asr_model = EncDecCTCModel.from_pretrained(\"QuartzNet15x5Base-En\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3e3c60e",
"metadata": {},
"outputs": [],
"source": [
"def setup_data(asr_model, vocabulary, data):\n",
" asr_model.preprocessor.featurizer.pad_to = 0\n",
" asr_model.preprocessor.featurizer.dither = 0.0\n",
" \n",
" asr_model.setup_test_data(\n",
" test_data_config={\n",
" 'sample_rate': 16000,\n",
" 'manifest_filepath': data,\n",
" 'labels': vocabulary,\n",
" 'batch_size': 64,\n",
" 'normalize_transcripts': False,\n",
" 'num_workers': 4,\n",
" }\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59982a9f",
"metadata": {},
"outputs": [],
"source": [
"vocabulary = asr_model.decoder.vocabulary\n",
"setup_data(asr_model, vocabulary, data)\n",
"\n",
"if can_gpu:\n",
" asr_model = asr_model.cuda()\n",
"\n",
"make_traceable(asr_model)\n",
"asr_model.eval();"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5e8d6ec",
"metadata": {},
"outputs": [],
"source": [
"minimal = Minimal(asr_model.preprocessor, asr_model.encoder, asr_model.decoder)\n",
"minimal.cpu()\n",
"# FIXME !!!\n",
"trace = torch.jit.trace(minimal, (), check_trace=False)"
]
},
{
"cell_type": "markdown",
"id": "949fa2db",
"metadata": {},
"source": [
"Save float version"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5818943",
"metadata": {},
"outputs": [],
"source": [
"#trace._save_for_lite_interpreter(\"trace_q_f.ptl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aedc5080",
"metadata": {},
"outputs": [],
"source": [
"qconfig_dict = qconfig_dict = {\"\": torch.quantization.get_default_qconfig('qnnpack')}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d667524",
"metadata": {},
"outputs": [],
"source": [
"asr_model.encoder = quantize_fx.prepare_fx(asr_model.encoder, qconfig_dict)\n",
"asr_model.decoder = quantize_fx.prepare_fx(asr_model.decoder, qconfig_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcadf17c",
"metadata": {},
"outputs": [],
"source": [
"for test_batch in asr_model.test_dataloader():\n",
" if can_gpu:\n",
" test_batch = [x.cuda() for x in test_batch]\n",
" log_probs, encoded_len, greedy_predictions = asr_model(\n",
" input_signal=test_batch[0], input_signal_length=test_batch[1]\n",
" )\n",
" del test_batch\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f41021e",
"metadata": {},
"outputs": [],
"source": [
"asr_model.encoder = quantize_fx.convert_fx(asr_model.encoder)\n",
"asr_model.decoder = quantize_fx.convert_fx(asr_model.decoder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a370aab",
"metadata": {},
"outputs": [],
"source": [
"asr_model.cpu();"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ae2f4b2",
"metadata": {},
"outputs": [],
"source": [
"minimal = Minimal(asr_model.preprocessor, asr_model.encoder, asr_model.decoder)\n",
"# FIXME !!!\n",
"trace = torch.jit.trace(minimal, (), check_trace=False)"
]
},
{
"cell_type": "markdown",
"id": "91f3d1fe",
"metadata": {},
"source": [
"Save int8 version"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "edd8a859",
"metadata": {},
"outputs": [],
"source": [
"trace._save_for_lite_interpreter(\"trace_q_i.ptl\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
