From ed4c74a210e005d8ed9e767a96b70b79271ab002 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 15 Apr 2021 16:29:33 +0800
Subject: [PATCH] Split an FsaVec into a batch of FsaVec to avoid CUDA OOM.

---
 snowfall/decoding/lm_rescore.py | 58 ++++++++++++++++++++++++++++-----
 1 file changed, 50 insertions(+), 8 deletions(-)

diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py
index 53fa51b5..7794390a 100644
--- a/snowfall/decoding/lm_rescore.py
+++ b/snowfall/decoding/lm_rescore.py
@@ -2,10 +2,52 @@
 
 from typing import Optional
 
+import math
+
 import k2
 import torch
 
 
+def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa, b_to_a_map: torch.Tensor,
+                      sorted_match_a: bool) -> k2.Fsa:
+    '''This is a wrapper of k2.intersect_device. Its purpose is to split
+    b_fsas into several batches and to process each batch separately so
+    that CUDA OOM errors are avoided.
+
+    The arguments and the return value of this function are the same as
+    those of k2.intersect_device.
+    '''
+    # NOTE: Decrease batch_size if you encounter a CUDA out-of-memory error.
+    batch_size = 500
+    num_fsas = b_fsas.shape[0]
+    if num_fsas <= batch_size:
+        return k2.intersect_device(a_fsas,
+                                   b_fsas,
+                                   b_to_a_map=b_to_a_map,
+                                   sorted_match_a=sorted_match_a)
+
+    num_batches = int(math.ceil(float(num_fsas) / batch_size))
+    splits = []
+    for i in range(num_batches):
+        start = i * batch_size
+        end = min(start + batch_size, num_fsas)
+        splits.append((start, end))
+
+    ans = []
+    for start, end in splits:
+        indexes = torch.arange(start, end).to(b_to_a_map)
+
+        fsas = k2.index(b_fsas, indexes)
+        b_to_a = k2.index(b_to_a_map, indexes)
+        path_lats = k2.intersect_device(a_fsas,
+                                        fsas,
+                                        b_to_a_map=b_to_a,
+                                        sorted_match_a=sorted_match_a)
+        ans.append(path_lats)
+
+    return k2.cat(ans)
+
+
 def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                       path_to_seq_map: torch.Tensor) -> torch.Tensor:
     '''Compute AM scores of n-best lists (represented as word_fsas).
@@ -44,10 +86,10 @@ def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
     del inverted_lats.aux_labels
     inverted_lats = k2.arc_sort(inverted_lats)
 
-    am_path_lats = k2.intersect_device(inverted_lats,
-                                       word_fsas_with_epsilon_loops,
-                                       b_to_a_map=path_to_seq_map,
-                                       sorted_match_a=True)
+    am_path_lats = _intersect_device(inverted_lats,
+                                     word_fsas_with_epsilon_loops,
+                                     b_to_a_map=path_to_seq_map,
+                                     sorted_match_a=True)
 
     # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
     am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)
@@ -145,10 +187,10 @@ def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
 
     # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
-    lm_path_lats = k2.intersect_device(G,
-                                       word_fsas_with_epsilon_loops,
-                                       b_to_a_map=b_to_a_map,
-                                       sorted_match_a=True)
+    lm_path_lats = _intersect_device(G,
+                                     word_fsas_with_epsilon_loops,
+                                     b_to_a_map=b_to_a_map,
+                                     sorted_match_a=True)
 
     lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
     lm_scores = lm_path_lats.get_tot_scores(True, True)
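
Note (not part of the patch): the heart of the change is the index arithmetic that carves the n-best FsaVec into chunks of at most batch_size FSAs before each call to k2.intersect_device. The sketch below reproduces only that arithmetic with plain PyTorch so it can be checked in isolation; the helper name split_into_batches and the example size 1203 are illustrative and do not appear in the patch.

import math

import torch


def split_into_batches(num_fsas: int, batch_size: int = 500):
    '''Yield index tensors covering [0, num_fsas) in chunks of at most
    batch_size elements, mirroring the (start, end) splitting done in
    _intersect_device above.
    '''
    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        yield torch.arange(start, end, dtype=torch.int32)


if __name__ == '__main__':
    # 1203 paths are split into three batches of sizes 500, 500, and 203.
    sizes = [b.numel() for b in split_into_batches(1203)]
    assert sizes == [500, 500, 203]

Each such index tensor is then used to pull out the corresponding sub-FsaVec and the matching entries of b_to_a_map, and the per-batch intersection results are concatenated at the end.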