
Commit

[Ready to merge] Pruned-transducer-stateless2 recipe for aidatatang_200zh (#375)

* add pruned-rnnt2 model for aidatatang_200zh

* do some changes

* change for README.md

* do some changes
luomingshuang authored May 24, 2022
1 parent 8c5722d commit c8c8645
Showing 27 changed files with 3,978 additions and 0 deletions.
38 changes: 38 additions & 0 deletions egs/aidatatang_200zh/ASR/README.md
@@ -0,0 +1,38 @@
Note: This recipe was trained with the code from this PR: https://github.com/k2-fsa/icefall/pull/375
# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall.
The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2.
## Training procedure
The main repositories are listed below; we will update the training and decoding scripts as new versions are released.
k2: https://github.com/k2-fsa/k2
icefall: https://github.com/k2-fsa/icefall
lhotse: https://github.com/lhotse-speech/lhotse
* Install k2 and lhotse. The k2 installation guide is at https://k2.readthedocs.io/en/latest/installation/index.html and the lhotse installation guide is at https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions of both should work. Please also install the requirements listed in icefall; a sketch of the installation steps follows.
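  A hypothetical sketch only (exact packages and versions depend on your CUDA and PyTorch setup; follow the guides above):
```
# Hypothetical commands; use builds matching your CUDA/PyTorch versions.
pip install k2
pip install lhotse
pip install -r icefall/requirements.txt
```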
* Clone icefall (https://github.com/k2-fsa/icefall) and check out the commit shown above.
```
git clone https://github.com/k2-fsa/icefall
cd icefall
```
* Preparing data.
```
cd egs/aidatatang_200zh/ASR
bash ./prepare.sh
```
* Training
```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 250
```
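Note: `--world-size 2` should match the number of GPUs made visible through `CUDA_VISIBLE_DEVICES`.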
## Evaluation results
The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below; we obtained them by averaging the models from epoch 11 to epoch 29.
The WERs are
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 |
| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500|
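For reference, a matching greedy-search decoding command (taken from the decoding commands in RESULTS.md below) is:
```
./pruned_transducer_stateless2/decode.py \
--epoch 29 \
--avg 19 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100
```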
72 changes: 72 additions & 0 deletions egs/aidatatang_200zh/ASR/RESULTS.md
@@ -0,0 +1,72 @@
## Results

### Aidatatang_200zh Char training results (Pruned Transducer Stateless2)

#### 2022-05-16

Using the code from this PR: https://github.com/k2-fsa/icefall/pull/375.

The WERs are

| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 |
| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500|

The training command for reproducing these results is given below:

```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 250 \
--save-every-n 1000
```

The tensorboard training log can be found at
https://tensorboard.dev/experiment/xS7kgYf2RwyDpQAOdS8rAA/#scalars

The decoding commands are:
```
epoch=29
avg=19
## greedy search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100
## modified beam search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
## fast beam search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
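In the fast beam search command, `--beam`, `--max-contexts`, and `--max-states` bound the FSA-based lattice search; larger values explore more hypotheses at a higher decoding cost.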

A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2>
Empty file.
109 changes: 109 additions & 0 deletions egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the aidatatang_200zh dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
    src_dir = Path("data/manifests/aidatatang_200zh")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())

    dataset_parts = (
        "train",
        "dev",
        "test",
    )
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts, output_dir=src_dir
    )
    assert manifests is not None

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            if (output_dir / f"cuts_{partition}.json.gz").is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition:
                # Augment the training set with 0.9x and 1.1x speed
                # perturbation, tripling the amount of training data.
                cut_set = (
                    cut_set
                    + cut_set.perturb_speed(0.9)
                    + cut_set.perturb_speed(1.1)
                )
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/feats_{partition}",
                # When an executor is specified, make more partitions.
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomHdf5Writer,
            )
            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="""The number of mel bins for Fbank""",
    )

    return parser.parse_args()


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
1 change: 1 addition & 0 deletions egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py
96 changes: 96 additions & 0 deletions egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,96 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum duration
thresholds for removing short and long utterances during training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""


from lhotse import load_manifest


def main():
    paths = [
        "./data/fbank/cuts_train.json.gz",
        "./data/fbank/cuts_dev.json.gz",
        "./data/fbank/cuts_test.json.gz",
    ]

    for path in paths:
        print(f"Starting to display the statistics for {path}")
        cuts = load_manifest(path)
        cuts.describe()


if __name__ == "__main__":
    main()

"""
Starting to display the statistics for ./data/fbank/cuts_train.json.gz
Cuts count: 494715
Total duration (hours): 422.6
Speech duration (hours): 422.6 (100.0%)
***
Duration statistics (seconds):
mean 3.1
std 1.2
min 1.0
25% 2.3
50% 2.7
75% 3.5
99% 7.2
99.5% 8.0
99.9% 9.5
max 18.1
Starting to display the statistics for ./data/fbank/cuts_dev.json.gz
Cuts count: 24216
Total duration (hours): 20.2
Speech duration (hours): 20.2 (100.0%)
***
Duration statistics (seconds):
mean 3.0
std 1.0
min 1.2
25% 2.3
50% 2.7
75% 3.4
99% 6.7
99.5% 7.3
99.9% 8.8
max 11.3
Starting to display the statistics for ./data/fbank/cuts_test.json.gz
Cuts count: 48144
Total duration (hours): 40.2
Speech duration (hours): 40.2 (100.0%)
***
Duration statistics (seconds):
mean 3.0
std 1.1
min 0.9
25% 2.3
50% 2.6
75% 3.4
99% 6.9
99.5% 7.5
99.9% 9.0
max 21.8
"""