Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implements RNNT+MMI #1030

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Only allow states sampled on final frame to be final state; add unnor…
…malized rnnt loss
  • Loading branch information
pkufool committed Aug 17, 2022
commit 5dc671fbab23ef5eec460003ec6b696764bc8ec0
87 changes: 54 additions & 33 deletions k2/csrc/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,7 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
Ragged<int32_t> &frame_ids,
Ragged<int32_t> &left_symbols,
Ragged<float> &sampling_probs,
Array1<int32_t> &boundary,
int32_t vocab_size,
int32_t context_size,
Array1<int32_t> *arc_map) {
Expand All @@ -2009,6 +2010,7 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
K2_DCHECK_EQ(sampled_paths.NumElements(),
left_symbols.NumElements() * context_size);
K2_DCHECK_EQ(sampled_paths.NumElements(), sampling_probs.NumElements());
K2_DCHECK_EQ(sampled_paths.TotSize(0), boundary.Dim());
for (int32_t i = 0; i < 3; ++i) {
K2_DCHECK_EQ(sampled_paths.TotSize(i), frame_ids.TotSize(i));
K2_DCHECK_EQ(sampled_paths.TotSize(i), left_symbols.TotSize(i));
Expand Down Expand Up @@ -2123,17 +2125,20 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
us_idx0 = us_row_ids1_data[ss_idx0x],
us_idx0_next_minus_1 = us_row_ids1_data[ss_idx0x_next - 1],
num_unique_states = us_idx0_next_minus_1 - us_idx0 + 1;
// Plus 2 here, because we need a super dest_state for the last sampled
// symbol of each path, and a final state needed by k2.
num_states_for_seqs_data[idx0] = num_unique_states + 2;
// Plus 3 here, because we need a super dest_state for the states sampled
// on the last frame (this dest_state will point to the final state),
// a fake super dest_state for the last states of linear paths that
// are not sampled on the last frames (this fake dest_state will be
// removed by connect operation), and a final state needed by k2.
num_states_for_seqs_data[idx0] = num_unique_states + 3;
});

ExclusiveSum(num_states_for_seqs, &num_states_for_seqs);
RaggedShape seqs_to_states_shape = RaggedShape2(
&num_states_for_seqs, nullptr, -1);
int32_t num_merged_states = seqs_to_states_shape.NumElements();

K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 2,
K2_CHECK_EQ(unique_states_shape.RowSplits(1).Dim() - 1 + num_seqs * 3,
num_merged_states);

// Plus 1 here because we will apply ExclusiveSum on this array.
Expand All @@ -2153,15 +2158,17 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
int32_t idx0 = sts_row_ids1_data[idx01],
idx0x_next = sts_row_splits1_data[idx0 + 1],
num_arcs = 0;
// The final state for each sequence.
// The final arc for each sequence.
if (idx01 == idx0x_next - 2) num_arcs = 1;
if (idx01 < idx0x_next - 2) {
// Minus idx0 * 2, because we add extra two states for each sequence.
int32_t us_idx0 = idx01 - idx0 * 2,
if (idx01 < idx0x_next - 3) {
// Minus idx0 * 3, because we add extra three states for each sequence.
int32_t us_idx0 = idx01 - idx0 * 3,
us_idx0x = us_row_splits1_data[us_idx0],
us_idx0x_next = us_row_splits1_data[us_idx0 + 1];
num_arcs = us_idx0x_next - us_idx0x;
}
// idx01 == idx0x_next - 3 (i.e. the fake super dest_state) and
// idx01 == idx0x_next - 1 (i.e. the final state) don't have arcs.
num_arcs_for_states_data[idx01] = num_arcs;
});

Expand All @@ -2185,6 +2192,7 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
*arcs_shape_row_splits1_data = arcs_shape.RowSplits(1).Data(),
*arcs_shape_row_ids2_data = arcs_shape.RowIds(2).Data(),
*states_row_ids2_data = states.RowIds(2).Data(),
*boundary_data = boundary.Data(),
*ss_row_ids1_data = sorted_states.RowIds(1).Data();
const float *sampling_probs_data = sampling_probs.values.Data();
Array1<Arc> arcs(c, num_arcs);
Expand Down Expand Up @@ -2233,37 +2241,50 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,

arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only include the "predictor" head output in C++ part, the other two scores (i.e. hybrid output and lm_output) will add on python part, it would be easier to enable autograd for hybrid output.


// Final state of the last sequence, it will point to the added super
// dest_state.
K2_DCHECK_LT(frame_ids_data[states_idx012], boundary_data[idx0]);

int32_t idx0x_next = arcs_shape_row_splits1_data[idx0 + 1];

// Handle the final state of last sequence.
if (states_idx012 == num_states - 1) {
int32_t idx0x_next = arcs_shape_row_splits1_data[idx0 + 1];
arc.dest_state = idx0x_next - idx0x - 2;
// If current state is on final frame, it will point to the added
// super dest_state.
if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1) {
arc.dest_state = idx0x_next - idx0x - 2;
} else {
// point to the fake added dest_state.
arc.dest_state = idx0x_next - idx0x - 3;
}
} else {
// states_idx01 is path index
int32_t states_idx01 = states_row_ids2_data[states_idx012],
states_idx01_next =
states_row_ids2_data[states_idx012 + 1],
frame_id = frame_ids_data[states_idx012],
frame_id_next = frame_ids_data[states_idx012 + 1];
// The first condition means this is the final state of each
// sequence.
// The second condition means we reach final frame at this state,
// the next state will be a start state of another path.
// So, this state points to the added super dest_state.
if (states_idx01 != states_idx01_next ||
(states_idx01 == states_idx01_next &&
frame_id_next < frame_id)) {
int32_t idx0x_next =
arcs_shape_row_splits1_data[idx0 + 1];
arc.dest_state = idx0x_next - idx0x - 2;
states_row_ids2_data[states_idx012 + 1];
if (states_idx01 != states_idx01_next) {
// If current state is on final frame, it will point to the added
// super dest_state.
if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1) {
arc.dest_state = idx0x_next - idx0x - 2;
} else {
// point to the fake added dest_state.
arc.dest_state = idx0x_next - idx0x - 3;
}
} else {
// states_idx012 + 1 is the index of original consecutive state.
// "ss" is short for "sorted states"
// "us" is short for "unique states".
int32_t ss_idx01_next =
sorted_states_old2new_data[states_idx012 + 1],
us_idx0_next = us_row_ids1_data[ss_idx01_next];
arc.dest_state = us_idx0_next + 2 * idx0 - idx0x;
// If current state is on final frame, it will point to the added
// super dest_state.
if (frame_ids_data[states_idx012] == boundary_data[idx0] - 1 &&
frame_ids_data[states_idx012 + 1] != boundary_data[idx0] - 1) {
arc.dest_state = idx0x_next - idx0x - 2;
} else {
// states_idx012 + 1 is the index of original consecutive state.
// "ss" is short for "sorted states"
// "us" is short for "unique states".
int32_t ss_idx01_next =
sorted_states_old2new_data[states_idx012 + 1],
us_idx0_next = us_row_ids1_data[ss_idx01_next];
// Plus 3 * idx0, because we add 3 states for each sequence
arc.dest_state = us_idx0_next + 3 * idx0 - idx0x;
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,7 @@ void Reverse(FsaVec &src, FsaVec *dest, Array1<int32_t> *arc_map = nullptr);
* elements MUST satisfy `0 <= value < vocab_size`.
* @param [in] sampling_probs It contains the probabilities of sampling each
* symbol, which has the same shape as sampled_paths.
* @param [in] boundary It contains the number of frames for each sequence.
* @param [in] vocab_size The vocabulary size.
* @param [in] context_size The number of left symbols.
* @param [out] arc_map For each arc in the return Fsa, gives the original
Expand All @@ -905,6 +906,7 @@ FsaVec GenerateDenominatorLattice(Ragged<int32_t> &sampled_paths,
Ragged<int32_t> &frame_ids,
Ragged<int32_t> &left_symbols,
Ragged<float> &sampling_probs,
Array1<int32_t> &boundary,
int32_t vocab_size,
int32_t context_size,
Array1<int32_t> *arc_map);
Expand Down
9 changes: 2 additions & 7 deletions k2/csrc/fsa_algo_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1414,16 +1414,11 @@ TEST(FsaAlgo, TestGenerateDenominatorLattice) {
" [ 0.2 0.2 0.2 0.1 0.2 0.1 0.2 0.2 ] "
" [ 0.1 0.1 0.2 0.2 0.3 0.3 0.3 0.3 ] ] "
"]");
Ragged<float> path_scores(c, "[ [ [ 1 1 1 1 1 1 1 1 ] "
" [ 1 2 2 1 2 2 1 2 ] "
" [ 1 1 1 3 3 3 3 3 ] ] "
" [ [ 1 1 1 1 1 1 1 1 ] "
" [ 1 2 1 2 2 2 2 2 ] "
" [ 1 1 1 2 3 3 3 3 ] ] ]");
Array1<int32_t> boundary(c, "[ 3 4 ]");

Array1<int32_t> arc_map;
FsaVec lattice = GenerateDenominatorLattice(
sampled_paths, frame_ids, left_symbols, sampling_probs,
sampled_paths, frame_ids, left_symbols, sampling_probs, boundary,
10 /*vocab_size*/, 2 /*context_size*/, &arc_map);
K2_LOG(INFO) << arc_map;
K2_LOG(INFO) << lattice;
Expand Down
6 changes: 4 additions & 2 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -770,21 +770,23 @@ static void PybindGenerateDenominatorLattice(py::module &m) {
"generate_denominator_lattice",
[](RaggedAny &sampled_paths, RaggedAny &frame_ids,
RaggedAny &left_symbols, RaggedAny &sampling_probs,
int32_t vocab_size, int32_t context_size)
torch::Tensor &boundary, int32_t vocab_size, int32_t context_size)
-> std::pair<FsaVec, torch::Tensor> {
DeviceGuard guard(sampled_paths.any.Context());
Array1<int32_t> arc_map;
Array1<int32_t> boundary_array = FromTorch<int32_t>(boundary);
FsaVec lattice = GenerateDenominatorLattice(
sampled_paths.any.Specialize<int32_t>(),
frame_ids.any.Specialize<int32_t>(),
left_symbols.any.Specialize<int32_t>(),
sampling_probs.any.Specialize<float>(),
boundary_array,
vocab_size, context_size, &arc_map);
auto arc_map_tensor = ToTorch(arc_map);
return std::make_pair(lattice, arc_map_tensor);
},
py::arg("sampled_paths"), py::arg("frame_ids"), py::arg("left_symbols"),
py::arg("sampling_probs"), py::arg("vocab_size"),
py::arg("sampling_probs"), py::arg("boundary"), py::arg("vocab_size"),
py::arg("context_size"));
}

Expand Down
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
from .rnnt_loss import get_rnnt_logprobs_smoothed
from .rnnt_loss import get_rnnt_prune_ranges
from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_for_numerator
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
Expand Down
6 changes: 5 additions & 1 deletion k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,7 @@ def generate_denominator_lattice(
left_symbols: torch.Tensor,
sampling_probs: torch.Tensor,
path_scores: torch.Tensor,
boundary: torch.Tensor,
vocab_size: int,
context_size: int,
) -> Fsa:
Expand All @@ -1415,6 +1416,8 @@ def generate_denominator_lattice(
It contains the scores of each sampled symbol, which has a same shape as
sampled_paths. It might contain the output of hybrid head and the extra
language model output. Note: Autograd is supported for this tensor.
boundary:
It contains the number of frames for each sequence.
vocab_size:
The vocabulary size.
context_size:
Expand All @@ -1425,13 +1428,14 @@ def generate_denominator_lattice(
frame_ids=k2.RaggedTensor(frame_ids),
left_symbols=k2.RaggedTensor(left_symbols),
sampling_probs=k2.RaggedTensor(sampling_probs),
boundary=boundary,
vocab_size=vocab_size,
context_size=context_size,
)
lattice = Fsa(ragged_arc)
a_value = getattr(lattice, "scores")
# Enable autograd for path_scores
b_value = index_select(path_scores.flatten(), arc_map)
value = a_value + b_value
value = b_value - a_value
setattr(lattice, "scores", value)
return lattice
Loading