WIP: Add BPE training code.
csukuangfj committed Jul 29, 2021
1 parent bd69e4b commit acc63a9
Showing 15 changed files with 1,144 additions and 267 deletions.
602 changes: 602 additions & 0 deletions egs/librispeech/ASR/conformer_ctc/train.py

Large diffs are not rendered by default.

43 changes: 19 additions & 24 deletions egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -189,6 +189,8 @@ def decoder_forward(
supervision: Supervisions = None,
graph_compiler: object = None,
token_ids: List[int] = None,
sos_id: Optional[int] = None,
eos_id: Optional[int] = None,
) -> Tensor:
"""
Args:
@@ -197,6 +199,8 @@
supervision: Supervision in lhotse format, obtained from batch['supervisions']
graph_compiler: uses graph_compiler.L_inv (its labels are words, while its aux_labels are phones), graph_compiler.words and graph_compiler.oov
sos_id: sos token id
eos_id: eos token id
Returns:
Tensor: Decoder loss.
@@ -206,18 +210,9 @@
supervision, graph_compiler.lexicon.words, graph_compiler.oov
)
ys_in_pad, ys_out_pad = add_sos_eos(
batch_text,
graph_compiler.L_inv,
self.decoder_num_class - 1,
self.decoder_num_class - 1,
batch_text, graph_compiler.L_inv, sos_id, eos_id,
)
elif token_ids is not None:
# special token ids:
# <blank> 0
# <UNK> 1
# <sos/eos> self.decoder_num_class - 1
sos_id = self.decoder_num_class - 1
eos_id = self.decoder_num_class - 1
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys_in = [
@@ -259,7 +254,12 @@ def decoder_nll(
return decoder_loss

def decoder_nll(
self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
self,
x: Tensor,
encoder_mask: Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> Tensor:
"""
Args:
@@ -273,12 +273,6 @@
# The common part between this function and decoder_forward could be
# extracted into a separate function.
if token_ids is not None:
# special token ids:
# <blank> 0
# <UNK> 1
# <sos/eos> self.decoder_num_class - 1
sos_id = self.decoder_num_class - 1
eos_id = self.decoder_num_class - 1
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys_in = [
@@ -866,7 +860,8 @@ def forward(self, x: Tensor, target: Tensor) -> Tensor:
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
# denom = total if self.normalize_length else batch_size
denom = total if self.normalize_length else 1
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
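
The change above replaces the batch_size denominator with 1: when normalize_length is False, the smoothed loss becomes a plain sum over non-padding tokens rather than a per-utterance average. A standalone sketch of the computation shown above (shapes assumed: x is (N, num_classes) logits, target is (N,) token ids):

import torch

def label_smoothing_kl(x, target, smoothing=0.1, padding_idx=-1,
                       normalize_length=False):
    num_classes = x.size(1)
    confidence = 1.0 - smoothing
    # Smoothed target distribution: `confidence` on the true class,
    # the remainder spread uniformly over the other classes.
    true_dist = torch.full_like(x, smoothing / (num_classes - 1))
    ignore = target == padding_idx
    total = len(target) - ignore.sum().item()
    target = target.masked_fill(ignore, 0)  # avoid -1 index
    true_dist.scatter_(1, target.unsqueeze(1), confidence)
    kl = torch.nn.functional.kl_div(
        torch.log_softmax(x, dim=1), true_dist, reduction="none"
    )
    denom = total if normalize_length else 1  # was batch_size before this commit
    return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom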


@@ -983,26 +978,26 @@ def generate_square_subsequent_mask(sz: int) -> Tensor:
def add_sos_eos(
ys: List[List[int]],
lexicon: k2.Fsa,
sos: int,
eos: int,
sos_id: int,
eos_id: int,
ignore_id: int = -1,
) -> Tuple[Tensor, Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys: batch of unpadded target sequences
lexicon: Its labels are words, while its aux_labels are phones.
sos: index of <sos>
eos: index of <eos>
sos_id: index of <sos>
eos_id: index of <eos>
ignore_id: index of padding
Returns:
Tensor: Input of the transformer decoder. Padded tensor of dimension (batch_size, max_length).
Tensor: Output of the transformer decoder. Padded tensor of dimension (batch_size, max_length).
"""

_sos = torch.tensor([sos])
_eos = torch.tensor([eos])
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys = get_hierarchical_targets(ys, lexicon)
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
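
The hunks above thread sos_id and eos_id through decoder_forward, decoder_nll, and add_sos_eos as explicit arguments instead of hard-coding them to self.decoder_num_class - 1, so a BPE model can supply its own special token ids (hypothetically, e.g. sos_id = eos_id = sp.piece_to_id("<sos/eos>") for a sentencepiece processor sp). A minimal standalone sketch of the sos/eos wrapping and padding; the k2 lexicon step is omitted, and the padding values (eos_id for inputs, ignore_id for outputs) are assumptions rather than taken from this diff:

import torch
from torch.nn.utils.rnn import pad_sequence

def add_sos_eos_sketch(ys, sos_id, eos_id, ignore_id=-1):
    # Decoder input gets a leading <sos>; decoder output a trailing <eos>.
    _sos = torch.tensor([sos_id])
    _eos = torch.tensor([eos_id])
    ys = [torch.tensor(y, dtype=torch.long) for y in ys]
    ys_in = [torch.cat([_sos, y]) for y in ys]
    ys_out = [torch.cat([y, _eos]) for y in ys]
    # Pad to (batch_size, max_length); padding values are assumptions here.
    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
    ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=ignore_id)
    return ys_in_pad, ys_out_pad

# add_sos_eos_sketch([[3, 4], [5]], sos_id=1, eos_id=1) yields
# inputs [[1, 3, 4], [1, 5, 1]] and outputs [[3, 4, 1], [5, 1, -1]].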
57 changes: 32 additions & 25 deletions egs/librispeech/ASR/local/compile_hlg.py
@@ -3,7 +3,7 @@
"""
This script compiles HLG from
- H, the ctc topology, built from phones contained in lexicon.txt
- H, the ctc topology, built from tokens contained in lexicon.txt
- L, the lexicon, built from L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
@@ -13,6 +13,7 @@
The generated HLG is saved in data/lm/HLG.pt (phone based)
or data/lm/HLG_bpe.pt (BPE based)
"""
import logging
from pathlib import Path

import k2
@@ -32,72 +33,72 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
print(f"Building ctc_topo. max_token_id: {max_token_id}")
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

if Path("data/lm/G_3_gram.pt").is_file():
print("Loading pre-compiled G_3_gram")
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d)
else:
print("Loading G_3_gram.fst.txt")
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "G_3_gram.pt")

first_token_disambig_id = lexicon.phones["#0"]
first_word_disambig_id = lexicon.words["#0"]
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]

L = k2.arc_sort(L)
G = k2.arc_sort(G)

print("Intersecting L and G")
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
print(f"LG shape: {LG.shape}")
logging.info(f"LG shape: {LG.shape}")

print("Connecting LG")
logging.info("Connecting LG")
LG = k2.connect(LG)
print(f"LG shape after k2.connect: {LG.shape}")
logging.info(f"LG shape after k2.connect: {LG.shape}")

print(type(LG.aux_labels))
print("Determinizing LG")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")

LG = k2.determinize(LG)
print(type(LG.aux_labels))
logging.info(type(LG.aux_labels))

print("Connecting LG after k2.determinize")
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)

print("Removing disambiguation symbols on LG")
logging.info("Removing disambiguation symbols on LG")

LG.labels[LG.labels >= first_token_disambig_id] = 0

assert isinstance(LG.aux_labels, k2.RaggedInt)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0

LG = k2.remove_epsilon(LG)
print(f"LG shape after k2.remove_epsilon: {LG.shape}")
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

print("Arc sorting LG")
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)

print("Composing H and LG")
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")

print("Connecting LG")
logging.info("Connecting LG")
HLG = k2.connect(HLG)

print("Arc sorting LG")
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
print(f"HLG.shape: {HLG.shape}")
logging.info(f"HLG.shape: {HLG.shape}")

return HLG
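
The change from lexicon.phones / lexicon.words to lexicon.token_table / lexicon.word_table above relies on k2 symbol tables supporting symbol-to-id lookup for the disambiguation symbol "#0". A minimal sketch of that lookup, with illustrative table contents (real tables come from the lang directory):

import k2

# Illustrative symbol table; not the actual contents of data/lang.
symbols = k2.SymbolTable.from_str("<eps> 0\na 1\n#0 2")
assert symbols["#0"] == 2  # symbol -> id lookup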

@@ -106,20 +107,20 @@ def phone_based_HLG():
if Path("data/lm/HLG.pt").is_file():
return

print("Compiling phone based HLG")
logging.info("Compiling phone based HLG")
HLG = compile_HLG("data/lang")

print("Saving HLG.pt to data/lm")
logging.info("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG.pt")


def bpe_based_HLG():
if Path("data/lm/HLG_bpe.pt").is_file():
return

print("Compiling BPE based HLG")
logging.info("Compiling BPE based HLG")
HLG = compile_HLG("data/lang/bpe")
print("Saving HLG_bpe.pt to data/lm")
logging.info("Saving HLG_bpe.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")


@@ -129,4 +130,10 @@ def main():


if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)

logging.basicConfig(format=formatter, level=logging.INFO)

main()