diff --git a/snowfall/training/ctc_graph.py b/snowfall/training/ctc_graph.py
index ea4b198a..48c06fee 100644
--- a/snowfall/training/ctc_graph.py
+++ b/snowfall/training/ctc_graph.py
@@ -105,23 +105,20 @@ def __init__(self,
         phone_ids_with_blank = [0] + phone_ids
         self.ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
 
-    def compile(self, texts: Iterable[str]) -> k2.Fsa:
-        decoding_graphs = k2.create_fsa_vec(
-            [self.compile_one_and_cache(text) for text in texts])
-        # make sure the gradient is not accumulated
-        decoding_graphs.requires_grad_(False)
-        return decoding_graphs
-
-    @lru_cache(maxsize=100000)
-    def compile_one_and_cache(self, text: str) -> k2.Fsa:
-        tokens = (token if token in self.words else self.oov
-                  for token in text.split(' '))
-        word_ids = [self.words[token] for token in tokens]
+    def compile(self, texts: Iterable[str]) ->k2.Fsa:
+        word_ids = []
+        for text in texts:
+            tokens = (token if token in self.words else self.oov
+                      for token in text.split(' '))
+            word_id = [self.words[token] for token in tokens]
+            word_ids.append(word_id)
         label_graph = k2.linear_fsa(word_ids)
         decoding_graph = k2.connect(k2.intersect(label_graph,
                                                  self.L_inv)).invert_()
         decoding_graph = k2.arc_sort(decoding_graph)
         decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
         decoding_graph = k2.connect(decoding_graph)
+        # make sure the gradient is not accumulated
+        decoding_graph.requires_grad_(False)
         return decoding_graph