[Ready to merge] Pruned-transducer-stateless2 recipe for aidatatang_200zh (#375)
* add pruned-rnnt2 model for aidatatang_200zh
* do some changes
* change for README.md
* do some changes
1 parent 8c5722d, commit c8c8645
Showing 27 changed files with 3,978 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
Note: This recipe was trained with the code from PR https://github.com/k2-fsa/icefall/pull/375 | ||
# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall. | ||
The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2. | ||
## Training procedure | ||
The main repositories are listed below. We will update the training and decoding scripts as new versions are released. | ||
k2: https://github.com/k2-fsa/k2 | ||
icefall: https://github.com/k2-fsa/icefall | ||
lhotse: https://github.com/lhotse-speech/lhotse | ||
* Install k2 and lhotse. For k2, see the installation guide at https://k2.readthedocs.io/en/latest/installation/index.html; for lhotse, see https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions should work. Please also install the requirements listed in icefall. | ||
* Clone icefall (https://github.com/k2-fsa/icefall) and check out the commit shown above. | ||
``` | ||
git clone https://github.com/k2-fsa/icefall | ||
cd icefall | ||
``` | ||
* Prepare the data. | ||
``` | ||
cd egs/aidatatang_200zh/ASR | ||
bash ./prepare.sh | ||
``` | ||
* Training | ||
``` | ||
export CUDA_VISIBLE_DEVICES="0,1" | ||
./pruned_transducer_stateless2/train.py \ | ||
--world-size 2 \ | ||
--num-epochs 30 \ | ||
--start-epoch 0 \ | ||
--exp-dir pruned_transducer_stateless2/exp \ | ||
--lang-dir data/lang_char \ | ||
--max-duration 250 | ||
``` | ||
## Evaluation results | ||
The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below. We obtained these results by averaging the models from epoch 11 to epoch 29. | ||
The WERs are: | ||
| | dev | test | comment | | ||
|------------------------------------|------------|------------|------------------------------------------| | ||
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 | | ||
| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 | | ||
| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500| |
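The averaging mentioned above can be illustrated with a short sketch. It assumes that `--epoch 29 --avg 19` selects the 19 epochs ending at epoch 29 (which matches the stated range 11 to 29) and that averaging is a plain arithmetic mean over parameters. The helper names `epochs_to_average` and `average_checkpoints` are hypothetical, and plain floats stand in for model tensors; this is not icefall's actual implementation.

```python
from typing import Dict, List


def epochs_to_average(epoch: int, avg: int) -> List[int]:
    """Return the epoch indices selected by --epoch/--avg (assumed inclusive of `epoch`)."""
    return list(range(epoch - avg + 1, epoch + 1))


def average_checkpoints(checkpoints: List[Dict[str, float]]) -> Dict[str, float]:
    """Element-wise mean of parameters; plain floats stand in for tensors."""
    n = len(checkpoints)
    return {name: sum(c[name] for c in checkpoints) / n for name in checkpoints[0]}


# Settings from the table above: --epoch 29, --avg 19.
epochs = epochs_to_average(epoch=29, avg=19)
print(epochs[0], epochs[-1], len(epochs))

# Toy checkpoints (hypothetical values) to show the averaging step.
toy = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 2.0}]
print(average_checkpoints(toy))
```

With these settings the selected range starts at epoch 11 and ends at epoch 29, which agrees with the description above.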
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
## Results | ||
|
||
### Aidatatang_200zh Char training results (Pruned Transducer Stateless2) | ||
|
||
#### 2022-05-16 | ||
|
||
Using the code from PR https://github.com/k2-fsa/icefall/pull/375. | ||
|
||
The WERs are | ||
|
||
| | dev | test | comment | | ||
|------------------------------------|------------|------------|------------------------------------------| | ||
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 | | ||
| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 | | ||
| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500| | ||
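To make the differences between the decoding methods easier to compare, the sketch below computes the relative WER reduction of each method against greedy search. The WER values are copied from the table above; the helper `relative_reduction` is a hypothetical name introduced only for this illustration.

```python
# WER (%) values copied from the table above: (dev, test).
wers = {
    "greedy search": (5.53, 6.59),
    "modified beam search (beam size 4)": (5.27, 6.33),
    "fast beam search": (5.30, 6.34),
}


def relative_reduction(baseline: float, new: float) -> float:
    """Relative WER reduction in percent, measured against `baseline`."""
    return 100.0 * (baseline - new) / baseline


base_dev, base_test = wers["greedy search"]
for name, (dev, test) in wers.items():
    print(
        f"{name}: dev {relative_reduction(base_dev, dev):.2f}%, "
        f"test {relative_reduction(base_test, test):.2f}%"
    )
```

Both beam search variants improve on greedy search by roughly 4 to 5% relative on dev and slightly less on test.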
|
||
The training command to reproduce these results is given below: | ||
|
||
``` | ||
export CUDA_VISIBLE_DEVICES="0,1" | ||
./pruned_transducer_stateless2/train.py \ | ||
--world-size 2 \ | ||
--num-epochs 30 \ | ||
--start-epoch 0 \ | ||
--exp-dir pruned_transducer_stateless2/exp \ | ||
--lang-dir data/lang_char \ | ||
--max-duration 250 \ | ||
--save-every-n 1000 | ||
``` | ||
|
||
The tensorboard training log can be found at | ||
https://tensorboard.dev/experiment/xS7kgYf2RwyDpQAOdS8rAA/#scalars | ||
|
||
The decoding commands are: | ||
``` | ||
epoch=29 | ||
avg=19 | ||
## greedy search | ||
./pruned_transducer_stateless2/decode.py \ | ||
--epoch $epoch \ | ||
--avg $avg \ | ||
--exp-dir pruned_transducer_stateless2/exp \ | ||
--lang-dir ./data/lang_char \ | ||
--max-duration 100 | ||
## modified beam search | ||
./pruned_transducer_stateless2/decode.py \ | ||
--epoch $epoch \ | ||
--avg $avg \ | ||
--exp-dir pruned_transducer_stateless2/exp \ | ||
--lang-dir ./data/lang_char \ | ||
--max-duration 100 \ | ||
--decoding-method modified_beam_search \ | ||
--beam-size 4 | ||
## fast beam search | ||
./pruned_transducer_stateless2/decode.py \ | ||
--epoch $epoch \ | ||
--avg $avg \ | ||
--exp-dir ./pruned_transducer_stateless2/exp \ | ||
--lang-dir ./data/lang_char \ | ||
--max-duration 1500 \ | ||
--decoding-method fast_beam_search \ | ||
--beam 4 \ | ||
--max-contexts 4 \ | ||
--max-states 8 | ||
``` | ||
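The three invocations above share most of their flags and differ only in the decoding-specific ones. The sketch below builds the same command lines from Python. The helper `build_decode_cmd` and the dictionary keys are hypothetical names used only for this illustration; every flag and value is copied from the commands above, and nothing is executed.

```python
import shlex
from typing import Dict, List

# Flags shared by all three invocations above.
COMMON: List[str] = [
    "./pruned_transducer_stateless2/decode.py",
    "--epoch", "29",
    "--avg", "19",
    "--exp-dir", "pruned_transducer_stateless2/exp",
    "--lang-dir", "./data/lang_char",
]

# Per-method flags, copied from the commands above. The key "greedy_search"
# is only a label here: the original greedy command passes no
# --decoding-method flag and relies on the script's default.
METHOD_FLAGS: Dict[str, List[str]] = {
    "greedy_search": ["--max-duration", "100"],
    "modified_beam_search": [
        "--max-duration", "100",
        "--decoding-method", "modified_beam_search",
        "--beam-size", "4",
    ],
    "fast_beam_search": [
        "--max-duration", "1500",
        "--decoding-method", "fast_beam_search",
        "--beam", "4",
        "--max-contexts", "4",
        "--max-states", "8",
    ],
}


def build_decode_cmd(method: str) -> List[str]:
    """Return the argument list for one decoding run (not executed here)."""
    return COMMON + METHOD_FLAGS[method]


for method in METHOD_FLAGS:
    print(shlex.join(build_decode_cmd(method)))
```

The resulting lists could be passed to `subprocess.run` from the `egs/aidatatang_200zh/ASR` directory, assuming the experiment directory already contains the trained checkpoints.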
|
||
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2> |
Empty file.
109 changes: 109 additions & 0 deletions
egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) | ||
# | ||
# See ../../../../LICENSE for clarification regarding multiple authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
""" | ||
This file computes fbank features of the aidatatang_200zh dataset. | ||
It looks for manifests in the directory data/manifests. | ||
The generated fbank features are saved in data/fbank. | ||
""" | ||
|
||
import argparse | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
import torch | ||
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer | ||
from lhotse.recipes.utils import read_manifests_if_cached | ||
|
||
from icefall.utils import get_executor | ||
|
||
# Torch's multithreaded behavior needs to be disabled or | ||
# it wastes a lot of CPU and slows things down. | ||
# Do this outside of main() in case it needs to take effect | ||
# even when we are not invoking the main (e.g. when spawning subprocesses). | ||
torch.set_num_threads(1) | ||
torch.set_num_interop_threads(1) | ||
|
||
|
||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): | ||
    src_dir = Path("data/manifests/aidatatang_200zh") | ||
    output_dir = Path("data/fbank") | ||
    num_jobs = min(15, os.cpu_count()) | ||
|
||
    dataset_parts = ( | ||
        "train", | ||
        "dev", | ||
        "test", | ||
    ) | ||
    manifests = read_manifests_if_cached( | ||
        dataset_parts=dataset_parts, output_dir=src_dir | ||
    ) | ||
    assert manifests is not None | ||
|
||
    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) | ||
|
||
    with get_executor() as ex:  # Initialize the executor only once. | ||
        for partition, m in manifests.items(): | ||
            if (output_dir / f"cuts_{partition}.json.gz").is_file(): | ||
                logging.info(f"{partition} already exists - skipping.") | ||
                continue | ||
            logging.info(f"Processing {partition}") | ||
            cut_set = CutSet.from_manifests( | ||
                recordings=m["recordings"], | ||
                supervisions=m["supervisions"], | ||
            ) | ||
            if "train" in partition: | ||
                cut_set = ( | ||
                    cut_set | ||
                    + cut_set.perturb_speed(0.9) | ||
                    + cut_set.perturb_speed(1.1) | ||
                ) | ||
            cut_set = cut_set.compute_and_store_features( | ||
                extractor=extractor, | ||
                storage_path=f"{output_dir}/feats_{partition}", | ||
                # when an executor is specified, make more partitions | ||
                num_jobs=num_jobs if ex is None else 80, | ||
                executor=ex, | ||
                storage_type=LilcomHdf5Writer, | ||
            ) | ||
            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") | ||
|
||
|
||
def get_args(): | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument( | ||
        "--num-mel-bins", | ||
        type=int, | ||
        default=80, | ||
        help="""The number of mel bins for Fbank""", | ||
    ) | ||
|
||
    return parser.parse_args() | ||
|
||
|
||
if __name__ == "__main__": | ||
    formatter = ( | ||
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
    ) | ||
|
||
    logging.basicConfig(format=formatter, level=logging.INFO) | ||
|
||
    args = get_args() | ||
    compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins) |
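The script above adds two speed-perturbed copies (factors 0.9 and 1.1) of every cut in the train partition. The effect on cut count and duration can be sketched in plain Python without lhotse. The sketch assumes that playing audio at `factor` times the original speed divides its duration by `factor`; the helper names are hypothetical.

```python
from typing import List

# The two perturbation factors applied to the train partition in the script above.
SPEED_FACTORS = (0.9, 1.1)


def perturbed_durations(duration: float) -> List[float]:
    """Durations of the original cut followed by its speed-perturbed copies.

    Assumption: a speed factor f changes the duration to duration / f.
    """
    return [duration] + [duration / f for f in SPEED_FACTORS]


def augmented_count(num_cuts: int) -> int:
    """Number of cuts after adding one copy per speed factor."""
    return num_cuts * (1 + len(SPEED_FACTORS))


for d in perturbed_durations(3.0):
    print(f"{d:.2f} s")
print(augmented_count(1000))
```

Under this assumption the train partition holds three times as many cuts as the raw manifest, and slightly more than three times the audio duration, because the slowed-down copy gains more time than the sped-up copy loses.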
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../librispeech/ASR/local/compute_fbank_musan.py |
96 changes: 96 additions & 0 deletions
egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang | ||
# Mingshuang Luo) | ||
# | ||
# See ../../../../LICENSE for clarification regarding multiple authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
This file displays duration statistics of utterances in a manifest. | ||
You can use the displayed value to choose minimum/maximum duration | ||
to remove short and long utterances during the training. | ||
See the function `remove_short_and_long_utt()` | ||
in ../../../librispeech/ASR/transducer/train.py | ||
for usage. | ||
""" | ||
|
||
|
||
from lhotse import load_manifest | ||
|
||
|
||
def main(): | ||
    paths = [ | ||
        "./data/fbank/cuts_train.json.gz", | ||
        "./data/fbank/cuts_dev.json.gz", | ||
        "./data/fbank/cuts_test.json.gz", | ||
    ] | ||
|
||
    for path in paths: | ||
        print(f"Starting display the statistics for {path}") | ||
        cuts = load_manifest(path) | ||
        cuts.describe() | ||
|
||
|
||
if __name__ == "__main__": | ||
    main() | ||
|
||
""" | ||
Starting display the statistics for ./data/fbank/cuts_train.json.gz | ||
Cuts count: 494715 | ||
Total duration (hours): 422.6 | ||
Speech duration (hours): 422.6 (100.0%) | ||
*** | ||
Duration statistics (seconds): | ||
mean 3.1 | ||
std 1.2 | ||
min 1.0 | ||
25% 2.3 | ||
50% 2.7 | ||
75% 3.5 | ||
99% 7.2 | ||
99.5% 8.0 | ||
99.9% 9.5 | ||
max 18.1 | ||
Starting display the statistics for ./data/fbank/cuts_dev.json.gz | ||
Cuts count: 24216 | ||
Total duration (hours): 20.2 | ||
Speech duration (hours): 20.2 (100.0%) | ||
*** | ||
Duration statistics (seconds): | ||
mean 3.0 | ||
std 1.0 | ||
min 1.2 | ||
25% 2.3 | ||
50% 2.7 | ||
75% 3.4 | ||
99% 6.7 | ||
99.5% 7.3 | ||
99.9% 8.8 | ||
max 11.3 | ||
Starting display the statistics for ./data/fbank/cuts_test.json.gz | ||
Cuts count: 48144 | ||
Total duration (hours): 40.2 | ||
Speech duration (hours): 40.2 (100.0%) | ||
*** | ||
Duration statistics (seconds): | ||
mean 3.0 | ||
std 1.1 | ||
min 0.9 | ||
25% 2.3 | ||
50% 2.6 | ||
75% 3.4 | ||
99% 6.9 | ||
99.5% 7.5 | ||
99.9% 9.0 | ||
max 21.8 | ||
""" |
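The statistics above can guide the choice of duration bounds for filtering training utterances, as the module docstring suggests. The sketch below uses a small stand-in `Cut` dataclass instead of lhotse objects. The bounds of 1.0 and 10.0 seconds are hypothetical values picked only for illustration (the train set's 99.9th percentile is 9.5 seconds and its maximum is 18.1 seconds); they are not the values used by the recipe.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Cut:
    """Hypothetical stand-in for a lhotse cut; only the duration matters here."""

    id: str
    duration: float  # seconds


def remove_short_and_long(
    cuts: Iterable[Cut],
    min_duration: float = 1.0,   # hypothetical lower bound
    max_duration: float = 10.0,  # hypothetical upper bound
) -> List[Cut]:
    """Keep only cuts whose duration lies within [min_duration, max_duration]."""
    return [c for c in cuts if min_duration <= c.duration <= max_duration]


cuts = [Cut("a", 0.9), Cut("b", 3.1), Cut("c", 18.1)]
kept = remove_short_and_long(cuts)
print([c.id for c in kept])
```

With a real `CutSet`, the same predicate could be applied through the library's own filtering mechanism; the bounds should be chosen from the percentiles printed by this script.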