Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Commit

Permalink
Split an FsaVec into several batches of FsaVecs to avoid CUDA OOM.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Apr 15, 2021
1 parent 16e9f59 commit ed4c74a
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions snowfall/decoding/lm_rescore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,52 @@

from typing import Optional

import math

import k2
import torch


def _intersect_device(a_fsas: k2.Fsa,
                      b_fsas: k2.Fsa,
                      b_to_a_map: torch.Tensor,
                      sorted_match_a: bool,
                      batch_size: int = 500) -> k2.Fsa:
    '''This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA OOM error.

    The first four arguments and the return value of this function are the
    same as k2.intersect_device.

    Args:
      a_fsas:
        Passed unchanged to every call of k2.intersect_device.
      b_fsas:
        The FSAs to split along axis 0; `b_fsas.shape[0]` is taken as the
        number of FSAs it contains.
      b_to_a_map:
        Indexed with the same indexes as `b_fsas`, so that entry `i` of
        every batch still refers to the FSA in `a_fsas` that the
        corresponding FSA in `b_fsas` has to be intersected with.
      sorted_match_a:
        Forwarded unchanged to k2.intersect_device.
      batch_size:
        Maximum number of FSAs from `b_fsas` handled by one call of
        k2.intersect_device. Decrease it in case of CUDA out of memory
        error. Defaults to 500, the value that used to be hard-coded.
    Returns:
      The result of k2.intersect_device if `b_fsas` fits into one batch;
      otherwise the per-batch results joined with k2.append, in the same
      order as the FSAs in `b_fsas`.
    Raises:
      ValueError: If `batch_size` is not positive.
    '''
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        # Everything fits into a single batch: no splitting is required.
        return k2.intersect_device(a_fsas,
                                   b_fsas,
                                   b_to_a_map=b_to_a_map,
                                   sorted_match_a=sorted_match_a)

    ans = []
    # Stepping by batch_size visits exactly ceil(num_fsas / batch_size)
    # batches; only the last one may be shorter than batch_size.
    for start in range(0, num_fsas, batch_size):
        end = min(start + batch_size, num_fsas)

        # Create the indexes directly with the dtype and on the device of
        # b_to_a_map instead of building them on CPU and converting.
        indexes = torch.arange(start,
                               end,
                               dtype=b_to_a_map.dtype,
                               device=b_to_a_map.device)

        fsas = k2.index(b_fsas, indexes)
        b_to_a = k2.index(b_to_a_map, indexes)
        path_lats = k2.intersect_device(a_fsas,
                                        fsas,
                                        b_to_a_map=b_to_a,
                                        sorted_match_a=sorted_match_a)
        ans.append(path_lats)

    return k2.append(ans)


def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
path_to_seq_map: torch.Tensor) -> torch.Tensor:
'''Compute AM scores of n-best lists (represented as word_fsas).
Expand Down Expand Up @@ -44,10 +86,10 @@ def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
del inverted_lats.aux_labels
inverted_lats = k2.arc_sort(inverted_lats)

am_path_lats = k2.intersect_device(inverted_lats,
word_fsas_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True)
am_path_lats = _intersect_device(inverted_lats,
word_fsas_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True)

# NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)
Expand Down Expand Up @@ -145,10 +187,10 @@ def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,

# Now compute lm_scores
b_to_a_map = torch.zeros_like(path_to_seq_map)
lm_path_lats = k2.intersect_device(G,
word_fsas_with_epsilon_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True)
lm_path_lats = _intersect_device(G,
word_fsas_with_epsilon_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True)
lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
lm_scores = lm_path_lats.get_tot_scores(True, True)

Expand Down

0 comments on commit ed4c74a

Please sign in to comment.