[Ready to merge] stateless6: stateless4 + hubert distillation. (#387)
* a copy of stateless4 as base
* distillation with hubert
* fix typo
* example usage
* usage
* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py
  Co-authored-by: Fangjun Kuang <[email protected]>
* fix comment
* add results of 100 hours
* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py
  Co-authored-by: Fangjun Kuang <[email protected]>
* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py
  Co-authored-by: Fangjun Kuang <[email protected]>
* check fairseq and quantization
* a short intro to distillation framework
* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py
  Co-authored-by: Fangjun Kuang <[email protected]>
* add intro of stateless6 in README
* fix type error of dst_manifest_dir
* Update egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py
  Co-authored-by: Fangjun Kuang <[email protected]>
* make export.py call stateless6/train.py instead of stateless2/train.py
* update results by stateless6
* adjust results format
* fix typo

Co-authored-by: Fangjun Kuang <[email protected]>
1 parent c8c8645 · commit c4ee2bc
Showing 23 changed files with 4,429 additions and 5 deletions.
144 changes: 144 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless6/distillation_with_hubert.sh
@@ -0,0 +1,144 @@
# A short introduction to the distillation framework.
#
# A typical traditional distillation method is
#   Loss(teacher embedding, student embedding).
#
# Compared to that, the proposed distillation framework contains two main steps:
#   codebook indexes = quantizer.encode(teacher embedding)
#   Loss(codebook indexes, student embedding)
#
# Things worth mentioning:
# 1. The float-type teacher embedding is quantized into a sequence of
#    8-bit integer codebook indexes.
# 2. A middle layer, layer 36 (1-based) of the 48 total layers, is used to
#    extract teacher embeddings.
# 3. Layer 6 (1-based) of the 6 total layers is used to
#    extract student embeddings.
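
# A toy, self-contained illustration of the quantization step (this is NOT
# the actual danpovey/quantization API; random centroids and the 1280-dim,
# 48-layer hubert-xtralarge geometry are assumed, purely to show the shapes
# and the 8-bit output type described above):
#
#   import torch
#   T, D, C = 10, 1280, 8        # frames, teacher embedding dim, num codebooks
#   teacher = torch.randn(T, D)  # stand-in for a layer-36 hubert embedding
#   centroids = torch.randn(C, 256, D // C)  # 256 entries/codebook -> 8-bit indexes
#   sub = teacher.view(T, C, D // C)
#   dist = ((sub.unsqueeze(2) - centroids.unsqueeze(0)) ** 2).sum(-1)  # (T, C, 256)
#   codebook_indexes = dist.argmin(dim=-1).to(torch.uint8)  # (T, C): 8 bytes/frame
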
# This is an example to do distillation with librispeech clean-100 subset. | ||
# run with command: | ||
# bash distillation_with_hubert.sh [0|1|2|3|4] | ||
# | ||
# For example command | ||
# bash distillation_with_hubert.sh 0 | ||
# will download hubert model. | ||
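#
# The stages are meant to be run one at a time, in order. For example
# (illustrative):
#   for s in 0 1 2 3 4; do bash distillation_with_hubert.sh $s; done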
stage=$1

# Set the GPUs available.
# This script requires at least one GPU.
# You MUST set the environment variable "CUDA_VISIBLE_DEVICES",
# even if you only have ONE GPU. It is needed by CodebookIndexExtractor
# to determine the number of jobs that extract codebook indexes in parallel.

# Suppose only one GPU exists:
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPUs 2,3,4,5 are available:
export CUDA_VISIBLE_DEVICES="2,3,4,5"
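
# Optional sanity check (illustrative): print how many GPUs the setting
# above exposes, using the same awk idiom as the training stage below.
num_gpus=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
echo "Using ${num_gpus} GPU(s): ${CUDA_VISIBLE_DEVICES}"
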
if [ $stage -eq 0 ]; then
  # Preparation stage.

  # Install fairseq according to:
  #   https://github.com/pytorch/fairseq
  # When testing this code,
  # commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 was used.
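  # One possible way to install it at that commit (illustrative; adjust to
  # your environment):
  #   git clone https://github.com/pytorch/fairseq
  #   cd fairseq
  #   git checkout 806855bf660ea748ed7ffb42fe8dcc881ca3aca0
  #   pip install --editable ./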
  has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
  if [ "$has_fairseq" == "False" ]; then
    echo "Please install fairseq before running the following stages."
    exit 1
  fi

  # Install the quantization toolkit:
  #   pip install git+https://github.com/danpovey/quantization.git@master
  # When testing this code,
  # commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b was used.
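  # To reproduce that exact state, the install can be pinned to the commit
  # (illustrative):
  #   pip install git+https://github.com/danpovey/quantization.git@c17ffe67aa2e6ca6b6855c50fde812f2eed7870b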

  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
  if [ "$has_quantization" == "False" ]; then
    echo "Please install the quantization toolkit before running the following stages."
    exit 1
  fi

  echo "Download the hubert model."
  # Model parameters.
  exp_dir=./pruned_transducer_stateless6/exp/
  model_id=hubert_xtralarge_ll60k_finetune_ls960
  hubert_model_dir=${exp_dir}/hubert_models
  hubert_model=${hubert_model_dir}/${model_id}.pt
  mkdir -p ${hubert_model_dir}
  # For more models, refer to:
  #   https://github.com/pytorch/fairseq/tree/main/examples/hubert
  if [ -f ${hubert_model} ]; then
    echo "hubert model already exists."
  else
    wget -c https://dl.fbaipublicfiles.com/hubert/${model_id}.pt -P ${hubert_model_dir}
    wget -c https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
  fi
fi

if [ ! -d ./data/fbank ]; then
  echo "This script assumes ./data/fbank has already been generated by prepare.sh"
  exit 1
fi

if [ $stage -eq 1 ]; then
  # This stage is not directly used by codebook index extraction.
  # It is a way to "prove" that the downloaded hubert model runs
  # inference correctly: decoding should produce WERs close to the following.
  # Expected WERs:
  # [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
  # [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
  ./pruned_transducer_stateless6/hubert_decode.py
fi

if [ $stage -eq 2 ]; then
  # Disk usage analysis:
  # With num_codebooks == 8, each teacher embedding is quantized into
  # a sequence of eight 8-bit integers, i.e. only eight bytes are needed
  # per frame.
  # The training dataset, clean-100 with speed perturbations 0.9 and 1.1,
  # totals 300 hours.
  # The output frame rate of hubert is 50 frames per second.
  # Theoretically, 300 * 3600 * 50 * 8 / 1024 / 1024 ≈ 412M is needed.
  # The actual size of all "*.h5" files storing codebook indexes is 450M;
  # the extra ~38M is presumably metadata.
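
  # The arithmetic above, checked with shell integer arithmetic
  # (integer division rounds down to 411 MiB):
  echo "Expected codebook-index size: $(( 300 * 3600 * 50 * 8 / 1024 / 1024 )) MiB"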

  # Time consumption analysis:
  # For extracting the quantizer's training data (teacher embeddings), only
  # 1000 utterances from clean-100 are used.
  # Together with the quantizer training itself, this takes no more than
  # 20 minutes.
  #
  # For codebook index extraction,
  # with two NVIDIA A100 GPUs, around three hours are needed to process the
  # 300 hours of training data, i.e. clean-100 with speed perturbations
  # 0.9 and 1.1.

  # GPU usage:
  # During the extraction of the quantizer's training data (teacher
  # embeddings) and the quantizer training, only the FIRST GPU is used.
  # During codebook index extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES
  # are used.
  ./pruned_transducer_stateless6/extract_codebook_index.py \
    --full-libri False
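
  # The vq manifests consumed by the training stage below are written under
  # ./data/vq_fbank; an optional disk-usage check (illustrative):
  #   du -sh ./data/vq_fbank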
fi

if [ $stage -eq 3 ]; then
  # Example training script.
  # Note: it's better to set --spec-aug-time-warp-factor=-1.
  WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
  ./pruned_transducer_stateless6/train.py \
    --manifest-dir ./data/vq_fbank \
    --master-port 12359 \
    --full-libri False \
    --spec-aug-time-warp-factor -1 \
    --max-duration 300 \
    --world-size ${WORLD_SIZE} \
    --num-epochs 20
fi

if [ $stage -eq 4 ]; then
  # Results should be similar to:
  # errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt: %WER = 5.67
  # errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt: %WER = 15.60
  ./pruned_transducer_stateless6/decode.py \
    --decoding-method "modified_beam_search" \
    --epoch 20 \
    --avg 10 \
    --max-duration 200 \
    --exp-dir ./pruned_transducer_stateless6/exp
fi
1 change: 1 addition & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless6/__init__.py
@@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py
1 change: 1 addition & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless6/asr_datamodule.py
@@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py
1 change: 1 addition & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless6/beam_search.py
@@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py