From 66fb3516f0595aab0335d6f9594428b8817736ee Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Wed, 4 Jan 2023 16:06:32 +0800
Subject: [PATCH] fix lattice length

---
 k2/csrc/rnnt_decode.cu              | 187 ++++++++++++++++++++++++++--
 k2/csrc/rnnt_decode.h               |  75 ++++++++++-
 k2/python/csrc/torch/rnnt_decode.cu |  16 +++
 k2/python/k2/rnnt_decode.py         |  10 +-
 4 files changed, 276 insertions(+), 12 deletions(-)

diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu
index 1224d7b58..1e3d16331 100644
--- a/k2/csrc/rnnt_decode.cu
+++ b/k2/csrc/rnnt_decode.cu
@@ -247,7 +247,7 @@ RaggedShape RnntDecodingStreams::ExpandArcs() {
   return unpruned_arcs_shape;
 }
 
-Renumbering RnntDecodingStreams::DoFisrtPassPruning(
+Renumbering RnntDecodingStreams::DoFirstPassPruning(
     RaggedShape &unpruned_arcs_shape, const Array2<float> &logprobs) {
   NVTX_RANGE(K2_FUNC);
   K2_CHECK_EQ(unpruned_arcs_shape.NumAxes(), 4);
@@ -438,7 +438,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
   auto unpruned_arcs_shape = ExpandArcs();
 
   // (2) Do initial pruning.
-  auto pass1_renumbering = DoFisrtPassPruning(unpruned_arcs_shape, logprobs);
+  auto pass1_renumbering = DoFirstPassPruning(unpruned_arcs_shape, logprobs);
 
   // pass1_arcs_shape has a shape of [stream][context][state][arc]
   auto pass1_arcs_shape =
@@ -507,6 +507,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
             idx01 = uas_row_ids2_data[idx012], idx0 = uas_row_ids1_data[idx01],
             num_graph_states = num_graph_states_data[idx0];
         int64_t this_state = this_states_values_data[idx012];
+        int32_t this_graph_state = this_state % num_graph_states;
         double this_score = this_scores_data[idx012];
 
         // handle the implicit epsilon self-loop
@@ -515,7 +516,21 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
           // we assume termination symbol to be 0 here.
          scores_data[arc_idx] = this_score + logprobs_acc(idx01, 0);
           ArcInfo ai;
-          ai.graph_arc_idx01 = -1;
+          /*
+            Track the state index for implicit self-loop arcs.
+            Conveniently, type int32_t has range [-2147483648, 2147483647],
+            i.e. one more negative value than positive values, so every
+            state index can be mapped to a distinct negative value:
+              state (0)          --> graph_arc_idx01 (-1)
+              state (1)          --> graph_arc_idx01 (-2)
+              state (2)          --> graph_arc_idx01 (-3)
+              state (2147483647) --> graph_arc_idx01 (-2147483648)
+
+            Besides, the super final state has no self-loop, so there are
+            definitely enough negative values to represent the state indexes.
+          */
+          ai.graph_arc_idx01 = -this_graph_state - 1;
+          K2_CHECK_LT(ai.graph_arc_idx01, 0);
           ai.score = logprobs_acc(idx01, 0);
           ai.label = 0;
           arcs_data[arc_idx] = ai;
@@ -526,8 +541,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
           const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
           int64_t this_context_state = this_state / num_graph_states;
-          int32_t this_graph_state = this_state % num_graph_states,
-                  graph_idx0x = graph_row_split1_data[this_graph_state],
+          int32_t graph_idx0x = graph_row_split1_data[this_graph_state],
                   graph_idx01 = graph_idx0x + idx3 - 1;  // minus 1 here as
                                                          // epsilon self-loop
                                                          // takes the position 0.
@@ -711,9 +725,164 @@ void RnntDecodingStreams::GatherPrevFrames(
   }
 }
 
+void RnntDecodingStreams::GetFinalArcs() {
+  NVTX_RANGE(K2_FUNC);
+  /*
+    This function handles the last two steps of the generated lattice.
+    The relationship between variables in these two steps is:
+
+      arcs:                      last frame arcs                final arcs
+      states: {last frame state} ---------------> {final states} ---------> {super final state}  # noqa
+
+    The super final state has no leaving arcs.
+   */
+
+  int32_t frames = prev_frames_.size();
+
+  // with shape [stream][state][arc]
+  auto last_frame_shape = prev_frames_[frames - 1]->shape;
+
+  // Note: last_frame_arc_data is non-const.
+  // The original "dest_state" attribute of each element in last_frame_arc_data
+  // is the state index produced by the function GroupStatesByContexts.
+  // In this function the source states in last_frame are expanded again,
+  // and the expanded destination states are NOT grouped, to save time.
+  // So "dest_state" has to be re-assigned a new value.
+  ArcInfo *last_frame_arc_data = prev_frames_[frames - 1]->values.Data();
+  const int32_t *lfs_row_ids2_data = last_frame_shape.RowIds(2).Data(),
+                *lfs_row_ids1_data = last_frame_shape.RowIds(1).Data(),
+                *lfs_row_splits2_data = last_frame_shape.RowSplits(2).Data(),
+                *lfs_row_splits1_data = last_frame_shape.RowSplits(1).Data();
+
+  const int32_t *num_graph_states_data = num_graph_states_.Data();
+  const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1);
+  const Arc *const *graphs_arcs_data = graphs_.values.Data();
+
+  // Meaning of the name final_graph_states:
+  //   "final_" means it is for "final states";
+  //   "_graph_states" means it stores state indexes in the decoding graph.
+  // Though this variable could be calculated in both
+  // lambda_get_final_arcs_shape and lambda_populate_final_arcs,
+  // to save time it is calculated and cached in the former and
+  // reused in the latter.
+  Array1<int32_t> final_graph_states(c_, last_frame_shape.NumElements());
+  int32_t *final_graph_states_data = final_graph_states.Data();
+
+  // Calculate num_arcs for each final state.
+  Array1<int32_t> num_final_arcs(c_, last_frame_shape.NumElements() + 1);
+  int32_t *num_final_arcs_data = num_final_arcs.Data();
+
+  K2_EVAL(
+      c_, last_frame_shape.NumElements(), lambda_get_final_arcs_shape,
+      (int32_t idx012) {
+        // Initialized here to save one extra kernel launch.
+        num_final_arcs_data[idx012] = 0;
+
+        int32_t idx01 = lfs_row_ids2_data[idx012],  // state_idx01
+                idx0 = lfs_row_ids1_data[idx01],    // stream_idx0
+                arc_idx0x = lfs_row_splits1_data[idx0],
+                arc_idx0xx = lfs_row_splits2_data[arc_idx0x],
+                arc_idx12 = idx012 - arc_idx0xx;
+
+        ArcInfo &ai = last_frame_arc_data[idx012];
+
+        // Re-assign dest_state to a new value.  See the more detailed comment
+        // at the definition of last_frame_arc_data above.
+        ai.dest_state = arc_idx12;
+
+        if (ai.label == -1) {
+          num_final_arcs_data[idx012] = 0;
+          // -(num_graph_states_data[idx0]) marks a state that is not expandable.
+          final_graph_states_data[idx012] = -(num_graph_states_data[idx0]);
+          return;
+        }
+        int dest_state = -1;
+        const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
+        const Arc *graph_arcs_data = graphs_arcs_data[idx0];
+        if (ai.graph_arc_idx01 < 0) {
+          // For implicit self-loop arcs.
+          dest_state = -(ai.graph_arc_idx01 + 1);
+          K2_CHECK_GE(dest_state, 0);
+          K2_CHECK_LE(dest_state, num_graph_states_data[idx0]);
+        } else {
+          // For other arcs shown in the decoding graph.
+          dest_state = graph_arcs_data[ai.graph_arc_idx01].dest_state;
+        }
+        K2_CHECK_GE(dest_state, 0);
+
+        final_graph_states_data[idx012] = dest_state;
+        // Plus one for the implicit epsilon self-loop.
+        num_final_arcs_data[idx012] = graph_row_split1_data[dest_state + 1] -
+                                      graph_row_split1_data[dest_state] + 1;
+      });
+
+  ExclusiveSum(num_final_arcs, &num_final_arcs);
+  auto final_arcs_shape = RaggedShape2(&num_final_arcs, nullptr, -1);
+  final_arcs_shape = ComposeRaggedShapes(last_frame_shape, final_arcs_shape);
+  // [stream][state][arc][arc] --> [stream][arc][arc],
+  // which could be viewed as [stream][final state][arc].
+  final_arcs_shape = RemoveAxis(final_arcs_shape, 1);
+  const int32_t *fas_row_ids1_data = final_arcs_shape.RowIds(1).Data();
+  const int32_t *fas_row_ids2_data = final_arcs_shape.RowIds(2).Data();
+  const int32_t *fas_row_splits2_data = final_arcs_shape.RowSplits(2).Data();
+
+  auto final_arcs = Ragged<ArcInfo>(final_arcs_shape);
+  ArcInfo *final_arcs_data = final_arcs.values.Data();
+  K2_EVAL(
+      c_, final_arcs_shape.NumElements(), lambda_populate_final_arcs,
+      (int32_t idx012) {
+        const int32_t idx01 = fas_row_ids2_data[idx012],  // state
+                      idx0 = fas_row_ids1_data[idx01],    // stream
+                      idx01x = fas_row_splits2_data[idx01],
+                      arc_idx2 = idx012 - idx01x;
+
+        const Arc *graph_arcs_data = graphs_arcs_data[idx0];
+        const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0];
+        int32_t graph_state_idx0 = final_graph_states_data[idx01];
+
+        int32_t ai_graph_arc_idx01 = 0;
+        int32_t ai_arc_label = 0;
+        if (graph_state_idx0 < 0) {
+          /*
+            Could be one of the following two cases:
+              case 1: not expandable if graph_state_idx0 == -(num_graph_states_data[idx0]); such states produce zero final arcs, so they never reach here  # noqa
+              case 2: implicit self-loop if graph_state_idx0 > -(num_graph_states_data[idx0])  # noqa
+          */
+          K2_DCHECK_GT(graph_state_idx0, -(num_graph_states_data[idx0]));
+          ai_arc_label = 0;
+          ai_graph_arc_idx01 = -1;
+        } else {
+          // For arcs shown in the decoding graph.
+          int32_t graph_arc_idx0x = graph_row_split1_data[graph_state_idx0];
+          // arc_idx2 can be used as a graph_arc_idx1 here, since
+          // final_arcs_shape (where arc_idx2 is computed) has 3 axes,
+          // while the decoding graph (where it is used) has only 2 axes.
+          ai_graph_arc_idx01 = graph_arc_idx0x + arc_idx2;
+          auto graph_arc = graph_arcs_data[ai_graph_arc_idx01];
+          ai_arc_label = graph_arc.label;
+        }
+        ArcInfo ai;
+        // ai.dest_state will be overwritten by FormatOutput,
+        // so just initialize it to -1 here.
+        ai.dest_state = -1;
+        ai.graph_arc_idx01 = ai_graph_arc_idx01;
+        ai.score = 0.0;
+        ai.label = ai_arc_label;
+        final_arcs_data[idx012] = ai;
+      });
+
+  prev_frames_.emplace_back(std::make_shared<Ragged<ArcInfo>>(final_arcs));
+}
+
 void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
                                        bool allow_partial, FsaVec *ofsa,
                                        Array1<int32_t> *out_map) {
+  FormatOutput(num_frames, allow_partial, false /* is_final */, ofsa, out_map);
+}
+
+void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
+                                       bool allow_partial, bool is_final,
+                                       FsaVec *ofsa, Array1<int32_t> *out_map) {
   NVTX_RANGE(K2_FUNC);
   K2_CHECK(!attached_)
       << "You can only get outputs after calling TerminateAndFlushToStreams()";
@@ -723,6 +892,10 @@ void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
   GatherPrevFrames(num_frames);
 
+  if (is_final) {
+    GetFinalArcs();
+  }
+
   int32_t frames = prev_frames_.size();
   auto last_frame_shape = prev_frames_[frames - 1]->shape;
@@ -873,11 +1046,11 @@ void RnntDecodingStreams::FormatOutput(const std::vector<int32_t> &num_frames,
         int32_t dest_state_idx012 = oarc_idx01x_next + arc_info.dest_state;
         arc.dest_state = dest_state_idx012 - oarc_idx0xx;
-        // graph_arc_idx01 == -1 means this is a implicit epsilon self-loop
+        // graph_arc_idx01 < 0 means this is an implicit epsilon self-loop
         // arc_info.label == -1 means this is the final arc before last
         // frame this is non-accessible arc, we set its label to 0 here to
         // make the generated lattice a valid k2 fsa.
-        if (arc_info.graph_arc_idx01 == -1 || arc_info.label == -1) {
+        if (arc_info.graph_arc_idx01 <= -1 || arc_info.label == -1) {
           arc.label = 0;
           arc_info.graph_arc_idx01 = -1;
         } else {
diff --git a/k2/csrc/rnnt_decode.h b/k2/csrc/rnnt_decode.h
index 9224ca807..58ff1e3bf 100644
--- a/k2/csrc/rnnt_decode.h
+++ b/k2/csrc/rnnt_decode.h
@@ -94,10 +94,25 @@ struct RnntDecodingConfig {
 struct ArcInfo {
   // The arc-index within the RnntDecodingStream::graph that corresponds to this
-  // arc, or -1 if this arc is a "termination symbol" (these do not appear in
-  // the graph).
+  // arc, if non-negative.
+  // There is an implicit self-loop arc for each state; such arcs are
+  // represented by -(state_index + 1), see the comments about dest_state_in_graph below.
   int32_t graph_arc_idx01;
 
+  // Note:
+  //   1. To save memory, the value of this variable is calculated
+  //      from graph_arc_idx01 instead of being stored.
+  //   2. It is different from the variable dest_state below:
+  //      dest_state_in_graph is the destination state index in the decoding graph,
+  //      while dest_state is the state index in the generated lattice.
+  // There are two kinds of arcs in the decoding graph:
+  //   1. Implicit self-loop arcs; dest_state_in_graph of these arcs is
+  //      calculated as -(graph_arc_idx01 + 1).
+  //      (Note: graph_arc_idx01 is negative for these arcs.)
+  //   2. Other arcs shown in the decoding graph; dest_state_in_graph of these
+  //      arcs is calculated as graph_arcs_data[ai.graph_arc_idx01].dest_state.
+  // int32_t dest_state_in_graph;
+
   // The score on the arc; contains both the graph score (if any) and the score
   // from the RNN-T joiner.
   float score;
@@ -199,6 +214,38 @@ class RnntDecodingStreams {
   void FormatOutput(const std::vector<int32_t> &num_frames, bool allow_partial,
                     FsaVec *ofsa, Array1<int32_t> *out_map);
 
+  /*
+    Generate the lattice.
+    Note: Almost the same as the previous overloaded version,
+    except for an extra `is_final` argument.
+
+    Note: prev_frames_ only contains frames decoded by the current object; in
+    order to generate the lattice we will first gather all the previous frames
+    from the individual streams.
+
+    @param [in] num_frames  A vector containing the number of frames we want
+                            to gather for each stream (note: the frames we
+                            have ever received).
+                            It MUST satisfy `num_frames.size() == num_streams_`,
+                            and `num_frames[i] <= srcs_[i].prev_frames.size()`.
+    @param [in] allow_partial  If true and there is no final state active,
+                            we will treat all the states on the last frame
+                            as final states. If false, we only care about
+                            the real final states in the decoding graph on
+                            the last frame when generating the lattice.
+    @param [in] is_final  If true, the function GetFinalArcs() will be called.
+                          If false, behaves the same as the previous
+                          overloaded version.
+    @param [out] ofsa  The output lattice will be written here; its num_axes
+                       equals 3 and it will be re-allocated.
+    @param [out] out_map  It is an Array1 with Dim() equal to
+                          ofsa.NumElements(), containing the idx01 into the
+                          graph of each individual stream, i.e. mapping each
+                          arc in ofsa back to the original decoding graphs.
+                          It may contain -1, which means the arc is a
+                          "termination symbol".
+   */
+  void FormatOutput(const std::vector<int32_t> &num_frames, bool allow_partial,
+                    bool is_final, FsaVec *ofsa, Array1<int32_t> *out_map);
+
   /*
     Terminate the decoding process of current RnntDecodingStreams object, it
     will update the states & scores of each individual stream and split &
@@ -261,8 +308,30 @@ class RnntDecodingStreams {
       @return  Return the renumbering object indicating which arc will be kept.
    */
-  Renumbering DoFisrtPassPruning(RaggedShape &unprund_arcs_shape,
+  Renumbering DoFirstPassPruning(RaggedShape &unprund_arcs_shape,
                                  const Array2<float> &logprobs);
+
+  /*
+    Get final arcs when the last frame is received, i.e. when is_final=True
+    is passed to the function `FormatOutput`.
+    Compared with OpenFst, a valid FSA in k2 needs arcs with label == -1
+    pointing to a super final state. This function handles these arcs.
+    See details of the problem solved by this function at
+    https://github.com/k2-fsa/k2/pull/1089
+
+    If we name the variables for the last two steps of a lattice as:
+      arcs:                      last frame arcs                final arcs
+      states: {last frame state} ---------------> {final states} ---------> {super final state}
+
+    this function mainly does the following steps:
+      1. get last_frame from prev_frames_
+      2. expand the last frame and get the final states
+      3. re-assign the dest_state of the last frame arcs to the final states
+      4. populate the final arcs
+      5. append the final arcs to prev_frames_
+  */
+  void GetFinalArcs();
+
   /*
     Group states by contexts.
diff --git a/k2/python/csrc/torch/rnnt_decode.cu b/k2/python/csrc/torch/rnnt_decode.cu
index 0a81deaed..7bda90397 100644
--- a/k2/python/csrc/torch/rnnt_decode.cu
+++ b/k2/python/csrc/torch/rnnt_decode.cu
@@ -152,6 +152,22 @@ static void PybindRnntDecodingStreams(py::module &m) {
         torch::Tensor out_map_tensor = ToTorch(out_map);
         return std::make_pair(ofsa, out_map_tensor);
       });
+
+  streams.def("format_output",
+              [](PyClass &self, std::vector<int32_t> &num_frames,
+                 bool allow_partial, bool is_final)
+                  -> std::pair<FsaVec, torch::Tensor> {
+                DeviceGuard guard(self.Context());
+                FsaVec ofsa;
+                Array1<int32_t> out_map;
+                self.FormatOutput(num_frames,
+                                  allow_partial,
+                                  is_final,
+                                  &ofsa,
+                                  &out_map);
+                torch::Tensor out_map_tensor = ToTorch(out_map);
+                return std::make_pair(ofsa, out_map_tensor);
+              });
 }
 
 }  // namespace k2
diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py
index 95f6aa8ab..7d2fe559c 100644
--- a/k2/python/k2/rnnt_decode.py
+++ b/k2/python/k2/rnnt_decode.py
@@ -149,7 +149,8 @@ def terminate_and_flush_to_streams(self) -> None:
     def format_output(
         self,
         num_frames: List[int],
-        allow_partial: bool = False
+        allow_partial: bool = False,
+        is_final: bool = False,
     ) -> Fsa:
         """
         Generate the lattice Fsa currently got.
@@ -173,6 +174,11 @@ def format_output(
             If false, we only care about the real final state in the decoding
             graph on the last frame when generating lattice. Default False.
+          is_final:
+            If true, the function GetFinalArcs() will be called. Default False.
+            See details of the problem solved by GetFinalArcs() at
+            https://github.com/k2-fsa/k2/pull/1089
+
         Returns:
           Return the lattice Fsa with all the attributes propagated.
         """
@@ -181,7 +187,7 @@ def format_output(
         assert len(num_frames) == self.num_streams
 
         ragged_arcs, out_map = self.streams.format_output(
-            num_frames, allow_partial
+            num_frames, allow_partial, is_final,
        )
 
        fsa = Fsa(ragged_arcs)
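
Usage note (not part of the patch): a minimal sketch of how the new `is_final` flag is meant to be used from the Python API, assuming `streams` is an already-constructed k2.RnntDecodingStreams whose frames have all been decoded, and `num_frames` holds the number of frames received per stream; the stream setup and joiner scoring loop are elided and the variable names are illustrative only.

    # Finish decoding, then request the lattice with final arcs appended
    # (GetFinalArcs() adds arcs entering a super final state so the result
    # is a valid k2 Fsa covering the full number of frames).
    streams.terminate_and_flush_to_streams()
    lattice = streams.format_output(
        num_frames=num_frames,   # frames received for each stream
        allow_partial=True,      # treat last-frame states as final if no real final state is active
        is_final=True,           # new flag from this patch; triggers GetFinalArcs()
    )
    # `lattice` is a k2.Fsa with all attributes propagated; without
    # is_final=True the behavior is identical to the old format_output().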