diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt
index 221c2fbc9..a97e291e8 100644
--- a/k2/csrc/CMakeLists.txt
+++ b/k2/csrc/CMakeLists.txt
@@ -71,8 +71,10 @@ set(context_srcs
   timer.cu
   top_sort.cu
   utils.cu
+  nbest.cu
 )
+
 if(K2_USE_PYTORCH)
   list(APPEND context_srcs pytorch_context.cu)
 else()
@@ -148,6 +150,7 @@ set(cuda_test_srcs
   log_test.cu
   macros_test.cu
   math_test.cu
+  nbest_test.cu
   nvtx_test.cu
   pinned_context_test.cu
   ragged_shape_test.cu
diff --git a/k2/csrc/array_ops.h b/k2/csrc/array_ops.h
index 1630ec9a2..d8c73e8d7 100644
--- a/k2/csrc/array_ops.h
+++ b/k2/csrc/array_ops.h
@@ -83,7 +83,7 @@ void ExclusiveSum(const Array1<S> &src, Array1<T> *dest) {
   ExclusiveSum(src.Context(), dest_dim, src.Data(), dest->Data());
 }
 
-/* wrapper for the ExclusiveSum above.  Will satisfy
+/* wrapper for the ExclusiveSum above (returns array with same dim as src).  Will satisfy
      ans[i] = sum_{j=0}^{i-1} src[j] for i > 0.
    ans[0] is always 0.
  */
diff --git a/k2/csrc/array_ops_inl.h b/k2/csrc/array_ops_inl.h
index a23b6e2b9..d24e9a58c 100644
--- a/k2/csrc/array_ops_inl.h
+++ b/k2/csrc/array_ops_inl.h
@@ -260,6 +260,7 @@ void ExclusiveSumDeref(Array1<const T *> &src, Array1<T> *dest) {
   if (dest_dim == src_dim + 1) {
     const RegionPtr &region = src.GetRegion();
     ssize_t byte_offset = static_cast<ssize_t>(src.ByteOffset());
+    // If this fails: you must allocate one extra element past the end of src!
     K2_CHECK_GE(region->num_bytes - byte_offset, dest_dim * src.ElementSize());
   }
   internal::PtrPtr<T> src_data = internal::PtrPtr<T>(src.Data());
@@ -285,6 +286,7 @@ void ExclusiveSum(const Array2<T> &src, Array2<T> *dest, int32_t axis) {
   if (dest_major_dim == src_major_dim + 1) {
     const RegionPtr &region = src.GetRegion();
     ssize_t byte_offset = static_cast<ssize_t>(src.ByteOffset());
+    // If this fails: you must allocate one extra element past the end of src!
     K2_CHECK_GE(region->num_bytes - byte_offset,
                 (src_major_dim * src_minor_dim + 1) * src.ElementSize());
   }
diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu
new file mode 100644
index 000000000..fd33281ae
--- /dev/null
+++ b/k2/csrc/nbest.cu
@@ -0,0 +1,635 @@
+/**
+ * Copyright      2021  Xiaomi Corporation (authors: Daniel Povey
+ *                                                   Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include "k2/csrc/nbest.h"
+
+// This is not really a CUDA file, but for build-system reasons I'm currently
+// leaving it with the .cu extension.
+
+namespace k2 {
+
+template <typename T>
+inline bool Leq(T a1, T a2, T b1, T b2) {
+  // lexicographic order for pairs, used in CreateSuffixArray()
+  return (a1 < b1 || (a1 == b1 && a2 <= b2));
+}
+template <typename T>
+inline bool Leq(T a1, T a2, T a3, T b1, T b2, T b3) {
+  // lexicographic order for triples, used in CreateSuffixArray()
+  return (a1 < b1 || (a1 == b1 && Leq(a2, a3, b2, b3)));
+}
+
+/*
+  Helper function for CreateSuffixArray().
+
+  Stably sorts a[0..n-1] to b[0..n-1] with keys in 0..K from r; i.e.
+  the values in a are interpreted as indexes into the array
+  `r` and the values in `r` are used for comparison, so that
+  at exit, r[b[i]] <= r[b[i+1]].
+*/
+template <typename T>
+static void RadixPass(const T *a, T *b, const T *r, T n, T K) {
+  T *c = new T[K + 1];                      // counter array
+  for (T i = 0; i <= K; i++) c[i] = 0;      // reset counters
+  for (T i = 0; i < n; i++) c[r[a[i]]]++;   // count occurrences
+  for (T i = 0, sum = 0; i <= K; i++) {     // exclusive prefix sums
+    T t = c[i];  c[i] = sum;  sum += t;
+  }
+  for (T i = 0; i < n; i++) b[c[r[a[i]]]++] = a[i];  // sort
+  delete[] c;
+}
+
+// See documentation in nbest.h, where we use different names
+// for the arguments (here, we leave the names the same as in
+// https://algo2.iti.kit.edu/documents/jacm05-revised.pdf).
+template <typename T>
+void CreateSuffixArray(const T *text, T n, T K, T *SA) {
+  // assert(text[0] <= text[n-1]);  // spot check that the termination symbol
+  //                                // is larger than the other symbols;
+  //                                // <= in case n == 1.
+  if (n == 1) {  // The paper's code didn't seem to handle n == 1 correctly.
+    SA[0] = 0;
+    return;
+  }
+  T n0 = (n + 2) / 3, n1 = (n + 1) / 3, n2 = n / 3, n02 = n0 + n2;
+  T *R = new T[n02 + 3];  R[n02] = R[n02 + 1] = R[n02 + 2] = 0;
+  T *SA12 = new T[n02 + 3];  SA12[n02] = SA12[n02 + 1] = SA12[n02 + 2] = 0;
+  T *R0 = new T[n0];
+  T *SA0 = new T[n0];
+  //******* Step 0: Construct sample ********
+  // generate positions of mod 1 and mod 2 suffixes
+  // the "+(n0-n1)" adds a dummy mod 1 suffix if n%3 == 1
+  for (T i = 0, j = 0; i < n + (n0 - n1); i++)
+    if (i % 3 != 0) R[j++] = i;
+  //******* Step 1: Sort sample suffixes ********
+  // lsb radix sort the mod 1 and mod 2 triples
+  RadixPass(R, SA12, text + 2, n02, K);
+  RadixPass(SA12, R, text + 1, n02, K);
+  RadixPass(R, SA12, text, n02, K);
+
+  // find lexicographic names of triples and
+  // write them to correct places in R
+  T name = 0, c0 = -1, c1 = -1, c2 = -1;
+  for (T i = 0; i < n02; i++) {
+    if (text[SA12[i]] != c0 || text[SA12[i] + 1] != c1 ||
+        text[SA12[i] + 2] != c2) {
+      name++;
+      c0 = text[SA12[i]];
+      c1 = text[SA12[i] + 1];
+      c2 = text[SA12[i] + 2];
+    }
+    if (SA12[i] % 3 == 1) {  // write to R1
+      R[SA12[i] / 3] = name;
+    } else {  // write to R2
+      R[SA12[i] / 3 + n0] = name;
+    }
+  }
+  // recurse if names are not yet unique
+  if (name < n02) {
+    CreateSuffixArray(R, n02, name, SA12);
+    // store unique names in R using the suffix array
+    for (T i = 0; i < n02; i++) R[SA12[i]] = i + 1;
+  } else {  // generate the suffix array of R directly
+    for (T i = 0; i < n02; i++) SA12[R[i] - 1] = i;
+  }
+  //******* Step 2: Sort nonsample suffixes ********
+  // stably sort the mod 0 suffixes from SA12 by their first character
+  for (T i = 0, j = 0; i < n02; i++)
+    if (SA12[i] < n0) R0[j++] = 3 * SA12[i];
+  RadixPass(R0, SA0, text, n0, K);
+  //******* Step 3: Merge ********
+  // merge sorted SA0 suffixes and sorted SA12 suffixes
+  for (T p = 0, t = n0 - n1, k = 0; k < n; k++) {
+    // i is pos of current offset 12 suffix
+    T i = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2);
+    T j = SA0[p];  // pos of current offset 0 suffix
+    if (SA12[t] < n0 ?
+        // different compares for mod 1 and mod 2 suffixes
+        Leq(text[i], R[SA12[t] + n0], text[j], R[j / 3]) :
+        Leq(text[i], text[i + 1], R[SA12[t] - n0 + 1], text[j],
+            text[j + 1], R[j / 3 + n0])) {  // suffix from SA12 is smaller
+      SA[k] = i;  t++;
+      if (t == n02)  // done --- only SA0 suffixes left
+        for (k++; p < n0; p++, k++) SA[k] = SA0[p];
+    } else {  // suffix from SA0 is smaller
+      SA[k] = j;  p++;
+      if (p == n0)  // done --- only SA12 suffixes left
+        for (k++; t < n02; t++, k++)
+          SA[k] = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2);
+    }
+  }
+  delete[] R;  delete[] SA12;  delete[] SA0;  delete[] R0;
+}
+
+// Instantiate template for int32_t and int16_t
+template void CreateSuffixArray(const int32_t *text, int32_t n,
+                                int32_t K, int32_t *SA);
+template void CreateSuffixArray(const int16_t *text, int16_t n,
+                                int16_t K, int16_t *SA);
+
+// This implements Kasai's algorithm, as summarized here
+// https://people.csail.mit.edu/jshun/lcp.pdf
+// (Note: there seem to be some wrong implementations of
+// Kasai's algorithm online).
+template <typename T>
+void CreateLcpArray(const T *array,
+                    const T *suffix_array,
+                    T seq_len,
+                    T *lcp_array) {
+  Array1<T> inv_suffix_array(GetCpuContext(), seq_len);
+  T *inv_suffix_array_data = inv_suffix_array.Data();
+  for (T i = 0; i < seq_len; i++) {
+    inv_suffix_array_data[suffix_array[i]] = i;
+  }
+  T k = 0;
+  if (seq_len > 0)
+    lcp_array[0] = 0;
+
+  for (T i = 0; i < seq_len; ++i) {
+    T cur_rank = inv_suffix_array[i];
+    if (cur_rank != 0) {
+      T j = suffix_array[cur_rank - 1];
+      while (array[i + k] == array[j + k])
+        ++k;
+      lcp_array[cur_rank] = k;
+      if (k > 0)
+        --k;
+    }
+  }
+}
+
+// Instantiate template for int32_t and int16_t
+template void CreateLcpArray(const int32_t *array, const int32_t *suffix_array,
+                             int32_t seq_len, int32_t *lcp_array);
+template void CreateLcpArray(const int16_t *array, const int16_t *suffix_array,
+                             int16_t seq_len, int16_t *lcp_array);
+
+template <typename T>
+void CreateLcpIntervalArray(ContextPtr c,
+                            T seq_len,
+                            T *lcp_array,
+                            Array1<LcpInterval<T> > *lcp_intervals,
+                            Array1<T> *leaf_parent_intervals) {
+  *lcp_intervals = Array1<LcpInterval<T> >(c, seq_len);
+  LcpInterval<T> *lcp_intervals_data = lcp_intervals->Data();
+
+  Array1<T> intervals_order(c, seq_len);
+  T *intervals_order_data = intervals_order.Data();
+
+  Array1<T> leaf_parent(c, seq_len);
+  T *leaf_parent_data = leaf_parent.Data();
+
+  // This is the stack from Algorithm 1 and Algorithm 2 of
+  // http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf
+  // (you can refer to the papers mentioned in the documentation in nbest.h
+  // if this link goes dead).
+  //
+  // The 'lb', 'rb' and 'lcp' members correspond to the 'lb', 'rb' and
+  // 'lcp' members mentioned there; the 'parent' member is used temporarily
+  // on the stack to refer to the index of this LcpInterval in
+  // `lcp_intervals_data`, i.e. it can be interpreted as a 'self' pointer.
+  std::vector<LcpInterval<T> > stack;
+
+  // A separate stack, of leaves of the suffix tree; we maintain this so that
+  // we can assign the `leaf_parent` array.
+  std::vector<T> leaf_stack;
+
+  // The initial element pushed onto the stack below has lcp=0, lb=0,
+  // rb=seq_len-1, self=0 (interpreting the 'parent' member as
+  // index-of-self).
+  // `next` will always store the next free index into `lcp_intervals_data`.
+  T next = 0;
+  // `dfs_next` will always store the next free index into
+  // `intervals_order_data`; this is an ordering of the indexes into
+  // `lcp_intervals_data` that corresponds to depth-first search.
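+  // (To make "dfs post order" concrete: an lcp-interval is numbered only
+  // after all of its child intervals have been numbered, so an interval's
+  // index is always less than its parent's and the top interval comes last;
+  // FindTightestNonemptyIntervals() relies on this property.)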
+  T dfs_next = 0;
+  T last_interval = -1;  // Will store an index into `lcp_intervals`; this
+                         // comes from Algorithm 2 mentioned above.
+  stack.push_back({0, 0, T(seq_len - 1), next++});
+  // We are using a numbering in which the terminating symbol $ is included
+  // in the array length, which is why we do "i < seq_len" and not
+  // "i <= seq_len" as in
+  // http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf
+  for (T i = 1; i < seq_len; ++i) {
+    T lb = i - 1, lcp_array_i = lcp_array[i];
+    leaf_stack.push_back(lb);
+
+    while (lcp_array_i < stack.back().lcp) {
+      last_interval = stack.back().parent;  // actually, the .parent field
+                                            // currently represents 'self',
+                                            // i.e. the index of the
+                                            // lcp-interval stack.back().
+      T last_interval_dfsorder = dfs_next++;
+      lb = stack.back().lb;
+      while (!leaf_stack.empty() && leaf_stack.back() >= lb) {
+        leaf_parent_data[leaf_stack.back()] = last_interval_dfsorder;
+        leaf_stack.pop_back();
+      }
+      // process(last_interval):
+      lcp_intervals_data[last_interval_dfsorder] = stack.back();
+      // Previously tried doing:
+      //   stack.back().rb = i - 1;
+      // a bit further above, but hit some kind of compiler problem; the
+      // assignment had no effect (back() is supposed to return a reference).
+      lcp_intervals_data[last_interval_dfsorder].rb = i - 1;
+      intervals_order_data[last_interval] = last_interval_dfsorder;
+      stack.pop_back();
+      if (lcp_array_i <= stack.back().lcp) {
+        // lcp_intervals_data[last_interval_dfsorder].parent represents
+        // the parent of `last_interval`; `stack.back().parent` currently
+        // represents the intended position of stack.back() itself,
+        // not of its parent.
+        lcp_intervals_data[last_interval_dfsorder].parent =
+            stack.back().parent;
+        last_interval = -1;
+      }
+    }
+    if (lcp_array_i > stack.back().lcp) {
+      if (last_interval >= 0) {
+        lcp_intervals_data[intervals_order_data[last_interval]].parent = next;
+        last_interval = -1;
+      }
+      stack.push_back({lcp_array_i, lb, -1, next++});
+    }
+  }
+  assert(stack.size() == 1);
+  T top_node_dfsorder = dfs_next++;
+  lcp_intervals_data[top_node_dfsorder] = stack.back();
+  lcp_intervals_data[top_node_dfsorder].parent = -1;
+  intervals_order_data[0] = top_node_dfsorder;
+  leaf_stack.push_back(seq_len - 1);
+  while (!leaf_stack.empty()) {
+    leaf_parent_data[leaf_stack.back()] = top_node_dfsorder;
+    leaf_stack.pop_back();
+  }
+  assert(dfs_next == next);
+  for (T i = 0; i + 1 < next; i++) {
+    // For each lcp-interval except the last (top) node, which has -1 as its
+    // parent field: change from pushing order (the order in which we pushed
+    // them onto the stack) to dfs post order (the order in which they were
+    // popped).
+    lcp_intervals_data[i].parent =
+        intervals_order_data[lcp_intervals_data[i].parent];
+  }
+
+  *lcp_intervals = lcp_intervals->Range(0, next);
+  for (T i = 0; i < next; i++)
+    intervals_order_data[i] = i;  // We output in dfs post order now.. will
+                                  // remove this output arg.
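+  // At this point, leaf_parent[i] holds, for each position i in the suffix
+  // array, the dfs-post-order index of the tightest lcp-interval enclosing
+  // that leaf; we copy it to the output only if the caller asked for it.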
+  if (leaf_parent_intervals)
+    *leaf_parent_intervals = leaf_parent;
+}
+
+// Instantiate template
+template
+void CreateLcpIntervalArray(ContextPtr c,
+                            int32_t seq_len,
+                            int32_t *lcp_array,
+                            Array1<LcpInterval<int32_t> > *lcp_intervals,
+                            Array1<int32_t> *leaf_parent_intervals);
+template
+void CreateLcpIntervalArray(ContextPtr c,
+                            int16_t seq_len,
+                            int16_t *lcp_array,
+                            Array1<LcpInterval<int16_t> > *lcp_intervals,
+                            Array1<int16_t> *leaf_parent_intervals);
+
+template <typename T>
+void FindTightestNonemptyIntervals(T seq_len,
+                                   Array1<LcpInterval<T> > *lcp_intervals,
+                                   Array1<T> *counts_exclusive_sum,
+                                   Array1<T> *leaf_parent_intervals) {
+  ContextPtr c = lcp_intervals->Context();
+  K2_CHECK(counts_exclusive_sum->Dim() == seq_len + 1);
+  K2_CHECK(leaf_parent_intervals->Dim() == seq_len);
+
+  const LcpInterval<T> *lcp_intervals_data = lcp_intervals->Data();
+  const T *counts_exclusive_sum_data = counts_exclusive_sum->Data();
+  int32_t num_intervals = lcp_intervals->Dim();
+  // `tightest_nonempty_interval` gives, for each interval
+  // 0 <= i < num_intervals, the index j >= i of the tightest enclosing
+  // interval that has a nonzero count.  As a special case, if all counts
+  // are zero, it will return the top (last) interval.
+  Array1<T> tightest_nonempty_interval(c, num_intervals);
+  T *tightest_nonempty_interval_data = tightest_nonempty_interval.Data();
+  for (T i = num_intervals - 1; i >= 0; --i) {
+    T j;
+    LcpInterval<T> cur_interval = lcp_intervals_data[i];
+    if (cur_interval.parent < 0 ||  // top node
+        counts_exclusive_sum_data[cur_interval.rb + 1] >
+            counts_exclusive_sum_data[cur_interval.lb]) {
+      j = i;
+    } else {
+      // j > i; we will already have set tightest_nonempty_interval_data
+      // at this location.
+      j = tightest_nonempty_interval_data[cur_interval.parent];
+    }
+    tightest_nonempty_interval_data[i] = j;
+  }
+  T *leaf_parent_intervals_data = leaf_parent_intervals->Data();
+  for (T i = 0; i < seq_len; ++i)
+    leaf_parent_intervals_data[i] = tightest_nonempty_interval_data[
+        leaf_parent_intervals_data[i]];
+}
+
+// Instantiate template
+template
+void FindTightestNonemptyIntervals(int32_t seq_len,
+                                   Array1<LcpInterval<int32_t> > *lcp_intervals,  // NOLINT
+                                   Array1<int32_t> *counts_exclusive_sum,
+                                   Array1<int32_t> *leaf_parent_intervals);
+template
+void FindTightestNonemptyIntervals(int16_t seq_len,
+                                   Array1<LcpInterval<int16_t> > *lcp_intervals,  // NOLINT
+                                   Array1<int16_t> *counts_exclusive_sum,
+                                   Array1<int16_t> *leaf_parent_intervals);
+
+// Internal implementation of GetBestMatchingStats(), that handles the case
+// where tokens.NumAxes() == 2 and tokens.NumElements() > 0.  It will
+// be instantiated with T == int16_t if the size of the problem permits, and
+// T == int32_t otherwise (this type is used for the internal suffix-array
+// computation).
+template <typename T>
+void GetBestMatchingStatsInternal(Ragged<int32_t> &tokens,
+                                  Array1<float> &scores,
+                                  Array1<int32_t> &counts,
+                                  T eos,
+                                  T min_token,
+                                  T max_token,
+                                  int32_t max_order,
+                                  Array1<float> *mean,
+                                  Array1<float> *var,
+                                  Array1<int32_t> *counts_out,
+                                  Array1<int32_t> *ngram_order) {
+  // Outline:
+  // First construct an array of type T which contains values as follows:
+  //   [ tokens.values[-1]+offset, ..., tokens.values[1]+offset,
+  //     tokens.values[0]+offset, eos+offset, terminator, 0, 0, 0 ]
+  // where offset is 1-min_token, and terminator is max_token+1+offset.
+  // The 3 terminating zeros are required by CreateSuffixArray().
+  //
+  // Call CreateSuffixArray (seq_len == tokens.NumElements() + 2, since we
+  // include the eos and terminator).
+  //
+  // Create the reordered counts array `counts_reordered`, in the same order
+  // as the suffix array, then its exclusive sum,
+  // e.g. `counts_reordered_excsum`.
+  // At this point we can also create similar
+  // reordered exclusive-sums of `scores` and scores-squared;
+  // do these as double or roundoff will be a problem.
+  //
+  // Call CreateLcpArray, CreateLcpIntervalArray,
+  // FindTightestNonemptyIntervals.
+  //
+  // By this point we should have enough information to directly create the
+  // outputs: mean, var, counts_out, ngram_order.  We need to be a bit
+  // careful about ngram_order at positions where the suffix goes up to the
+  // next eos (i.e. it goes to the beginning of the sentence), because the
+  // correct ngram order to output there is `max_order`.  You will have to
+  // create an array containing the distance from the beginning of the
+  // sentence (it can be constructed from the row_ids and row_splits of
+  // `tokens`).
+  //
+  // Note: we only really care about the output at the query positions, but
+  // try to make it so that you don't need to treat keys as a special case.
+  //
+  // Special cases/conditions to consider include:
+  //  - make sure the `count` in the positions of the eos and terminator
+  //    are zero
+  //  - various code may break if the total count over all these sentences is
+  //    zero, so you could just detect that and treat it as a special case.
+  //    If the total count is nonzero, it should be guaranteed that you never
+  //    have to process an interval with zero count;
+  //    FindTightestNonemptyIntervals() should guarantee that.
+  ContextPtr &c = tokens.Context();
+  T num_elements = tokens.NumElements();
+  K2_CHECK_EQ(mean->Dim(), num_elements);
+  K2_CHECK_EQ(var->Dim(), num_elements);
+  K2_CHECK_EQ(counts_out->Dim(), num_elements);
+  K2_CHECK_EQ(ngram_order->Dim(), num_elements);
+
+  T offset = 1 - min_token,
+    terminator = max_token + 1 + offset;
+  // we include the eos and terminator, so plus 2 here.
+  T seq_len = num_elements + 2;
+  // 3 extra zeros at the end are required by CreateSuffixArray().
+  Array1<T> text_array(c, seq_len + 3);
+  T *text_array_data = text_array.Data();
+  const int32_t *tokens_values_data = tokens.values.Data();
+  // we want to match the longest common prefix of each word and the words
+  // preceding it, so we reverse the sequence before constructing the
+  // suffix array.
+  for (T i = 0; i < num_elements; ++i) {
+    text_array_data[i] =
+        tokens_values_data[num_elements - i - 1] + offset;
+  }
+  T eos_offset = eos + offset;
+  // fill in the eos, terminator and required zeros
+  std::vector<T> tail({eos_offset, terminator, 0, 0, 0});
+  for (T i = num_elements; i < text_array.Dim(); ++i)
+    text_array_data[i] = tail[i - num_elements];
+
+  Array1<T> suffix_array(c, seq_len);
+  CreateSuffixArray(text_array.Data(), seq_len, terminator,
+                    suffix_array.Data());
+
+  // we need one extra position for `ExclusiveSum`
+  Array1<T> reorder_counts(c, seq_len + 1);
+  Array1<float> reorder_scores(c, seq_len + 1);
+  Array1<double> reorder_scores_square(c, seq_len + 1);
+  T *reorder_counts_data = reorder_counts.Data();
+  float *reorder_scores_data = reorder_scores.Data();
+  double *reorder_scores_square_data = reorder_scores_square.Data();
+
+  const int32_t *counts_data = counts.Data();
+  const float *scores_data = scores.Data();
+  const T *suffix_array_data = suffix_array.Data();
+  for (int32_t i = 0; i < suffix_array.Dim(); ++i) {
+    // we reversed the sequence above, so the order of counts and scores must
+    // be reversed accordingly; also make sure that the counts and scores are
+    // zero in the positions of the eos and terminator.
+    int32_t rindex = seq_len - 2 - suffix_array_data[i] - 1;
+    reorder_counts_data[i] = rindex < 0 ? 0 : counts_data[rindex];
+    reorder_scores_data[i] = rindex < 0 ?
+        0 : scores_data[rindex];
+    reorder_scores_square_data[i] = rindex < 0 ? 0 :
+        scores_data[rindex] * scores_data[rindex];
+  }
+  ExclusiveSum(reorder_counts, &reorder_counts);
+  ExclusiveSum(reorder_scores, &reorder_scores);
+  ExclusiveSum(reorder_scores_square, &reorder_scores_square);
+
+  // A total count of zero over all these sentences means that there are no
+  // **keys**; we cannot match anything, so return early as a special case.
+  if (reorder_counts_data[reorder_counts.Dim() - 1] == 0) {
+    *mean = 0;
+    *var = 0;
+    *counts_out = 0;
+    *ngram_order = 0;
+    return;
+  }
+
+  Array1<T> lcp_array(c, seq_len);
+  CreateLcpArray(text_array.Data(), suffix_array.Data(), seq_len,
+                 lcp_array.Data());
+
+  Array1<T> leaf_parent_interval;
+  Array1<LcpInterval<T> > lcp_intervals;
+  CreateLcpIntervalArray(c, seq_len, lcp_array.Data(),
+                         &lcp_intervals, &leaf_parent_interval);
+
+  FindTightestNonemptyIntervals(seq_len, &lcp_intervals,
+                                &reorder_counts, &leaf_parent_interval);
+  const LcpInterval<T> *lcp_intervals_data = lcp_intervals.Data();
+  const T *leaf_parent_interval_data = leaf_parent_interval.Data();
+
+  Array1<T> dist_to_begin(c, num_elements);
+  T *dist_to_begin_data = dist_to_begin.Data();
+  const int32_t *tokens_row_ids1_data = tokens.RowIds(1).Data(),
+                *tokens_row_splits1_data = tokens.RowSplits(1).Data();
+  K2_EVAL(
+      c, num_elements, lambda_set_dist_to_begin, (int32_t idx01) -> void {
+        int32_t idx0 = tokens_row_ids1_data[idx01],
+                idx0x = tokens_row_splits1_data[idx0],
+                idx1 = idx01 - idx0x;
+        dist_to_begin_data[idx01] = idx1 + 1;
+      });
+
+  // mapping from the original order to suffix-array order
+  Array1<T> inv_suffix_array(c, seq_len);
+  T *inv_suffix_array_data = inv_suffix_array.Data();
+  for (T i = 0; i < seq_len; i++) {
+    inv_suffix_array_data[suffix_array_data[i]] = i;
+  }
+  float *mean_data = mean->Data(),
+        *var_data = var->Data();
+  int32_t *counts_out_data = counts_out->Data(),
+          *ngram_order_data = ngram_order->Data();
+
+  // loop in the original order
+  for (T i = 0; i < num_elements; ++i) {
+    // we reversed `tokens.values` above; minus 2 here to remove the eos and
+    // terminator, which do not belong to `tokens`.
+    T text_array_index = seq_len - 2 - i - 1;
+    T suffix_index = inv_suffix_array_data[text_array_index];
+    // leaf_parent_interval, reorder_counts and reorder_scores are all
+    // indexed in suffix-array order.
+    T interval_index = leaf_parent_interval_data[suffix_index];
+    LcpInterval<T> interval = lcp_intervals_data[interval_index];
+    float scores_sum = reorder_scores_data[interval.rb + 1] -
+                       reorder_scores_data[interval.lb];
+    double scores_square_sum = reorder_scores_square_data[interval.rb + 1] -
+                               reorder_scores_square_data[interval.lb];
+    int32_t counts_out_interval = reorder_counts_data[interval.rb + 1] -
+                                  reorder_counts_data[interval.lb];
+    if (interval.lcp == 0) {  // the tightest interval is the root interval
+      K2_CHECK_EQ(interval.parent, -1);
+      counts_out_data[i] = 0;
+      ngram_order_data[i] = 0;
+    } else {
+      counts_out_data[i] = counts_out_interval;
+      ngram_order_data[i] = std::min(interval.lcp, (T)max_order);
+      // handle the sentence boundary
+      if (dist_to_begin_data[i] <= interval.lcp)
+        ngram_order_data[i] = max_order;
+    }
+    mean_data[i] = counts_out_interval == 0 ?
+        0 : (scores_sum / counts_out_interval);
+    if (counts_out_interval == 0 || counts_out_interval == 1) {
+      var_data[i] = 0;
+    } else {
+      double numerator = scores_square_sum - 2 * mean_data[i] * scores_sum +
+                         counts_out_interval * mean_data[i] * mean_data[i];
+      var_data[i] = numerator / counts_out_interval;
+    }
+  }
+}
+
+void GetBestMatchingStats(Ragged<int32_t> &tokens,
+                          Array1<float> &scores,
+                          Array1<int32_t> &counts,
+                          int32_t eos,
+                          int32_t min_token,
+                          int32_t max_token,
+                          int32_t max_order,
+                          Array1<float> *mean,
+                          Array1<float> *var,
+                          Array1<int32_t> *counts_out,
+                          Array1<int32_t> *ngram_order) {
+  ContextPtr &c = tokens.Context();
+  K2_CHECK_EQ(c->GetDeviceType(), kCpu);
+
+  int32_t num_elements = tokens.NumElements();
+  K2_CHECK(mean);
+  if (mean->Dim() != num_elements) {
+    *mean = Array1<float>(c, num_elements);
+  } else {
+    K2_CHECK_EQ(mean->Dim(), num_elements);
+  }
+  K2_CHECK(var);
+  if (var->Dim() != num_elements) {
+    *var = Array1<float>(c, num_elements);
+  } else {
+    K2_CHECK_EQ(var->Dim(), num_elements);
+  }
+  K2_CHECK(counts_out);
+  if (counts_out->Dim() != num_elements) {
+    *counts_out = Array1<int32_t>(c, num_elements);
+  } else {
+    K2_CHECK_EQ(counts_out->Dim(), num_elements);
+  }
+  K2_CHECK(ngram_order);
+  if (ngram_order->Dim() != num_elements) {
+    *ngram_order = Array1<int32_t>(c, num_elements);
+  } else {
+    K2_CHECK_EQ(ngram_order->Dim(), num_elements);
+  }
+
+  K2_CHECK(eos >= min_token && eos <= max_token);
+  K2_CHECK_GE(max_order, 2);
+  K2_CHECK_EQ(num_elements, scores.Dim());
+  K2_CHECK_EQ(num_elements, counts.Dim());
+
+  if (tokens.NumAxes() == 3) {
+    int32_t num_collections = tokens.Dim0();
+    for (int32_t i = 0; i < num_collections; i++) {
+      Ragged<int32_t> this_tokens = tokens.Index(0, i);
+      int32_t begin = this_tokens.values.Data() - tokens.values.Data(),
+              end = begin + this_tokens.NumElements();
+      Array1<float> this_scores = scores.Arange(begin, end),
+                    this_mean = mean->Arange(begin, end),
+                    this_var = var->Arange(begin, end);
+      Array1<int32_t> this_counts = counts.Arange(begin, end),
+                      this_counts_out = counts_out->Arange(begin, end),
+                      this_ngram_order = ngram_order->Arange(begin, end);
+      GetBestMatchingStats(this_tokens, this_scores, this_counts, eos,
+                           min_token, max_token, max_order,
+                           &this_mean, &this_var,
+                           &this_counts_out, &this_ngram_order);
+    }
+    return;
+  }
+  K2_CHECK_EQ(tokens.NumAxes(), 2);  // Only 2 or 3 axes are allowed.
+
+  if (num_elements == 0) {
+    return;  // Nothing to do.
+  } else if (num_elements + 10 < (1 << 15) &&
+             (max_token - min_token + 10 < (1 << 15))) {
+    GetBestMatchingStatsInternal<int16_t>(tokens, scores, counts, eos,
+                                          min_token, max_token, max_order,
+                                          mean, var, counts_out, ngram_order);
+  } else {
+    GetBestMatchingStatsInternal<int32_t>(tokens, scores, counts, eos,
+                                          min_token, max_token, max_order,
+                                          mean, var, counts_out, ngram_order);
+  }
+}
+
+}  // namespace k2
diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h
new file mode 100644
index 000000000..acd34eb19
--- /dev/null
+++ b/k2/csrc/nbest.h
@@ -0,0 +1,331 @@
+/**
+ * Copyright      2021  Xiaomi Corporation (authors: Daniel Povey)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_CSRC_NBEST_H_
+#define K2_CSRC_NBEST_H_
+
+#include <utility>
+#include <vector>
+
+#include "k2/csrc/algorithms.h"
+#include "k2/csrc/array.h"
+#include "k2/csrc/log.h"
+#include "k2/csrc/macros.h"
+#include "k2/csrc/ragged.h"
+#include "k2/csrc/utils.h"
+
+namespace k2 {
+
+// This header contains certain utility functions that are used in rescoring
+// n-best lists: specifically, functions that help us choose which among a set
+// of n-best candidates to select for rescoring.  The selection scheme is a
+// little complex.  It is intended to be used in a context where we do
+// multiple successive rounds of n-best list rescoring, and we use the results
+// of the first round to guide selection of candidates in the second round.
+// So for each word in each n-best path that we are considering, we find the
+// best-matching positions among those that we evaluated in the first round,
+// and we use those as inputs to a model that predicts the scores of words
+// after n-best-list rescoring.
+//
+// Some of these functions may seem a little unrelated to n-best lists; they
+// are algorithms involving suffix arrays, which we use internally in some
+// algorithms we use to process n-best lists.
+
+/*
+  This function creates a suffix array; it is based on the code in
+  https://algo2.iti.kit.edu/documents/jacm05-revised.pdf,
+  "Linear Work Suffix Array Construction" by J. Kärkkäinen et al.
+
+  Template args: T should be a signed integer type; we plan to
+  instantiate this for int32_t and int16_t only.
+
+    @param [in] text_array  Pointer to the input array of symbols,
+           including the termination symbol ($) which must be larger
+           than the other symbols.
+           All pointers must be CPU pointers only, for now.
+           The suffixes of this array are to be sorted.  Logically this
+           array has length `seq_len`, and symbols are required
+           to be in the range [1..max_symbol].
+           text_array is additionally required to be terminated by 3 zeros,
+           for purposes of this algorithm, i.e.
+           text_array[seq_len] == text_array[seq_len+1] == text_array[seq_len+2] == 0
+    @param [in] seq_len  Length of the symbol sequence (`text_array`
+           must be longer than this by at least 3, for termination.)
+           Require seq_len >= 0
+    @param [out] suffix_array  A pre-allocated array of length
+           `seq_len`.  At exit it will contain a permutation of
+           the list [ 0, 1, ... seq_len - 1 ], interpreted
+           as the start indexes of the nonempty suffixes of `text_array`,
+           with the property that the sub-arrays of `text_array`
+           starting at these positions are lexicographically sorted.
+           For example, as a trivial case, if seq_len == 4
+           and text_array contains [ 3, 2, 1, 10, 0, 0, 0 ], then
+           `suffix_array` would contain [ 2, 1, 0, 3 ] at exit.
+    @param [in] max_symbol  A number that must be >= the largest
+           number that might be in `text_array`, including the
+           termination symbol.  The work done
+           is O(seq_len + max_symbol), so it is not advisable
+           to let max_symbol be too large.
+  Caution: this function allocates memory internally (although
+  not much more than `text_array` itself).
+ */
+template <typename T>
+void CreateSuffixArray(const T *text_array,
+                       T seq_len,
+                       T max_symbol,
+                       T *suffix_array);
+
+/*
+  Create the LCP array, which is the array of lengths of longest common
+  prefixes (of successive pairs of suffixes).
+
+  Template args: T should be a signed integer type; we plan to instantiate
+  this for int32_t and int16_t only.
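+
+  For example (a small case worked by hand, not output copied from running
+  the code): if text_array is [ 2, 1, 2, 1, 3, 0, 0, 0 ] with seq_len == 5,
+  so that 3 is the termination symbol, then suffix_array is
+  [ 1, 3, 0, 2, 4 ] and lcp_array is [ 0, 1, 0, 2, 0 ]; e.g.
+  lcp_array[3] == 2 because the suffixes starting at positions 0 and 2
+  share the two-symbol prefix "2 1".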
+
+    @param [in] text_array  The array of symbols, of length `seq_len` plus at
+           least one terminating zero.  The symbols should be positive
+           (this may not be required here, but it is required by
+           CreateSuffixArray()).
+    @param [in] suffix_array  The suffix array, as created by
+           CreateSuffixArray(); it is a permutation of the numbers
+           0, 1, ... seq_len - 1.
+    @param [out] lcp_array  An array of length `seq_len` is output to here;
+           it is expected to be pre-allocated.  At exit, lcp_array[0]
+           will be 0 and lcp_array[i] for i > 0 will equal the length
+           of the longest common prefix of
+           (text_array+suffix_array[i-1]) and (text_array+suffix_array[i]).
+*/
+template <typename T>
+void CreateLcpArray(const T *text_array,
+                    const T *suffix_array,
+                    T seq_len,
+                    T *lcp_array);
+
+/*
+  Template args: T is a signed type, intended to be int16_t or int32_t
+
+  lcp-intervals correspond to the nodes in the suffix trie; they are a concept
+  used with suffix arrays, and are derived from the LCP table (see lcp_array
+  output of CreateLcpArray).  Take care with the notation here: intervals are
+  "closed intervals" so [i,j] means i,i+1,...,j, i.e. the RHS is the index of
+  the last element, not one past the last element.
+
+  Notation: [i,j] is an lcp-interval with lcp-value l, if:
+     0 <= i < j < seq_len
+     lcptab[i] < l
+     lcptab[j+1] < l
+     l is the minimum of (lcptab[i+1], lcptab[i+2], ..., lcptab[j])
+  lcp-intervals correspond to the internal nodes of the suffix trie, so
+  they always contain at least two children (where children can be
+  leaves, corresponding to indexes into the suffix array, or other
+  lcp-intervals).
+
+  SPECIAL CASE: if seq_len is 1, which is a rather degenerate case, the above
+  definitions do not quite work; we treat [0,0] as an lcp-interval with
+  lcp-value 0, although it does not meet the above definitions.
+
+  The type LcpInterval is used to store information about the lcp-interval,
+  which we'll later use in algorithms that traverse the suffix tree.
+ */
+template <typename T>
+struct LcpInterval {
+  // Represents the lcp-interval [lb,rb] with lcp-value `lcp`
+  T lcp;  // The lcp-value of the lcp-interval, which is the length of the
+          // longest prefix shared by all suffixes in this interval.
+  T lb;   // Index of the first element (left boundary)
+  T rb;   // Index of the last element (right boundary)
+  T parent;  // The parent of this lcp-interval
+             // (-1 if this is the top interval),
+             // in the order in which it appears in this array (of
+             // lcp-intervals).  Note: this order is neither top-down nor
+             // bottom-up; you can treat it as arbitrary.
+};
+
+template <typename T>
+std::ostream &operator<<(std::ostream &os, const LcpInterval<T> &interval) {
+  static constexpr char kSep = ' ';
+  os << interval.lcp << kSep << interval.lb << kSep << interval.rb << kSep
+     << interval.parent;
+  return os;
+}
+
+
+/*
+  Create an array of struct LcpInterval which describes the lcp-intervals
+  corresponding to the internal nodes of the suffix tree, and allows you
+  to easily run algorithms on this tree.  This data structure is not
+  very memory-optimized and doesn't correspond to anything in the literature,
+  although the basic tree traversal algorithm comes from
+  [Mohamed Ibrahim Abouelhoda, Stefan Kurtz, Enno Ohlebusch: Replacing suffix
+  trees with enhanced suffix arrays.  Journal of Discrete Algorithms 2 (2004)
+  53-86] and was originally adapted from [Kasai, Lee, Arimura, Arikawa, Park:
+  Linear-Time Longest-Common-Prefix Computation in Suffix Arrays and Its
+  Applications, CPM 2001].
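+
+  As a small hand-worked illustration (not output copied from the code):
+  continuing the example in the documentation of CreateLcpArray() above, for
+  lcp_array [ 0, 1, 0, 2, 0 ] the lcp-intervals are [0,1] with lcp-value 1,
+  [2,3] with lcp-value 2, and the top interval [0,4] with lcp-value 0; the
+  first two have the top interval as their parent.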
+
+  The motivation here is that we are likely limited more by time than memory,
+  and we want a data structure that is relatively simple to use.
+
+  Template args: T is a signed type, intended to be int16_t or int32_t
+
+     @param [in] c  Context pointer, used to create arrays.  Required to
+                  be a CPU context pointer for now.
+     @param [in] seq_len  The length of the text for which we have a suffix
+                  array.
+     @param [in] lcp_array  The LCP array, as computed by CreateLcpArray()
+     @param [out] lcp_intervals  A *newly created* array of LcpInterval<T>
+                  will be written to here, of length no greater than seq_len.
+                  They will be in dfs post order; children precede their
+                  parents.
+     @param [out] leaf_parent_intervals  If this is non-NULL, a newly
+                  created array of size seq_len will be written to here,
+                  saying, for each leaf in the suffix tree (corresponding to
+                  positions in the suffix array), which lcp-interval
+                  is its parent.  Indexes into this array correspond to
+                  indexes into the suffix array, and values correspond
+                  to indexes into `lcp_intervals`.
+ */
+template <typename T>
+void CreateLcpIntervalArray(ContextPtr c,
+                            T seq_len,
+                            T *lcp_array,
+                            Array1<LcpInterval<T> > *lcp_intervals,
+                            Array1<T> *leaf_parent_intervals);
+
+/*
+  Modifies `leaf_parent_intervals` to give us, for each position in the suffix
+  array (i.e. each suffix), the tightest enclosing lcp-interval that has
+  nonzero count.  This is used in finding the highest-order match of
+  a position in a text (i.e. the longest matching history).
+
+  Template args: T is a signed type, intended to be int16_t or int32_t
+
+    @param [in] seq_len  The length of the sequence,
+                 including the terminating $.
+    @param [in] lcp_intervals  The array of lcp-intervals, as returned
+                 by CreateLcpIntervalArray()
+    @param [in] counts_exclusive_sum  The exclusive-sum of counts of symbols
+                 in the original text array, in the order given by the suffix
+                 array, e.g. the original counts would have satisfied
+                 suffix_counts[i] = counts[suffix_array[i]], and then
+                 counts_exclusive_sum is the exclusive-sum of suffix_counts.
+                 Must satisfy counts_exclusive_sum->Dim() == seq_len + 1.
+                 The original counts would have been 1 for "keys" and 0 for
+                 "queries", so an interval with nonzero difference in
+                 counts_exclusive_sum is an interval containing at least
+                 one key.
+    @param [in,out] leaf_parent_intervals  At entry, this will contain,
+                 for each leaf of the suffix tree (i.e. each position
+                 in the suffix array), the index of the tightest enclosing
+                 lcp-interval, i.e. an index into `lcp_intervals`.
+                 At exit, it will contain the index of the tightest
+                 enclosing *nonempty* lcp-interval.
+ */
+template <typename T>
+void FindTightestNonemptyIntervals(T seq_len,
+                                   Array1<LcpInterval<T> > *lcp_intervals,
+                                   Array1<T> *counts_exclusive_sum,
+                                   Array1<T> *leaf_parent_intervals);
+
+/*
+  For "query" sentences, this function gets the mean and variance of scores
+  from the best matching words-in-context in a set of provided "key"
+  sentences.  This matching process matches each word and the words preceding
+  it, looking for the highest-order match it can find (it's intended for
+  approximating the scores of models that see only left-context, like
+  language models).  It is an efficient implementation using suffix arrays
+  (done on CPU for now, since the implementation is not very trivial).  The
+  intended application is in estimating the scores of hypothesized
+  transcripts, when we have actually computed the scores for only a subset
+  of the hypotheses.
+
+     @param [in] tokens  A ragged tensor of int32_t with 2 or 3 axes (this
+                  function recurses).
+                  If 2 axes, this represents a collection of
+                  key and query sequences (keys have count==1, queries have
+                  count==0).  If 3 axes, this represents a set of such
+                  collections, and retrieval should be done independently
+                  within each one.
+
+                  2-axis example:
+                    [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ]
+                  3-axis example:
+                    [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ],
+                      [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ]
+
+                  .. where the words would actually be represented as
+                  integers, and the eos might be -1.  The eos symbol is
+                  required if this code is to work as intended (otherwise
+                  this code will not be able to recognize when we have
+                  reached the beginnings of sentences when comparing
+                  histories).  bos symbols are allowed but not required.
+
+     @param [in] scores  An array with scores.Dim() == tokens.NumElements();
+                  this contains the values for which we are requesting
+                  best-matching statistics (as means and variances, in case
+                  there are multiple best matches).  In our anticipated use,
+                  these would represent scores of words in the sentences,
+                  but they could represent anything.
+     @param [in] counts  An array with counts.Dim() == tokens.NumElements(),
+                  containing 1 for words that are considered "keys" and 0 for
+                  words that are considered "queries".  Typically some entire
+                  sentences will be keys and others will be queries.
+     @param [in] eos  The value of the eos (end of sentence) symbol;
+                  internally, it is used as an extra padding value before the
+                  first sentence in each collection, so that it can act like
+                  a "bos" symbol.
+     @param [in] min_token  The lowest possible token value, including the
+                  bos symbol (e.g., might be -1).
+     @param [in] max_token  The maximum possible token value.  Be careful not
+                  to set this too large: the implementation contains a part
+                  which takes time and space O(max_token - min_token).
+     @param [in] max_order  The maximum n-gram order to ever return in the
+                  `ngram_order` output; the output will be the minimum of
+                  max_order and the actual order matched, or max_order if we
+                  matched all the way to the beginning of both sentences.
+                  The main reason this is needed is that we need a finite
+                  number to return at the beginning of sentences.
+
+     @param [out] mean  For query positions, will contain the mean of the
+                  scores at the best matching key positions, or zero if that
+                  is undefined because there are no key positions at all.
+                  For key positions, you can treat the output as being
+                  undefined (actually they are treated the same as queries,
+                  but won't match with only themselves because we don't match
+                  at singleton intervals).  This array will be allocated if
+                  it did not have the correct size at entry.
+     @param [out] var  Like `mean`, but contains the (centered) variance
+                  of the best matching positions.
+     @param [out] counts_out  The number of key positions that contributed
+                  to the `mean` and `var` statistics.  This should only
+                  be zero if `counts` was all zero.  Will be allocated
+                  if it did not have the correct size at entry.
+     @param [out] ngram_order  The n-gram order corresponding to the
+                  best matching positions found at each query position, up
+                  to a maximum of `max_order`; will be `max_order` if we
+                  matched all the way to the beginning of a sentence.  Will
+                  be allocated if it did not have the correct size at entry.
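+
+     A small hand-worked example of the statistics (not produced by running
+     the code): if a query position's best match is an interval containing
+     two key positions whose scores are 3 and 5, then mean == 4,
+     var == ((3-4)^2 + (5-4)^2) / 2 == 1, and counts_out == 2.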
+*/
+void GetBestMatchingStats(Ragged<int32_t> &tokens,
+                          Array1<float> &scores,
+                          Array1<int32_t> &counts,
+                          int32_t eos,
+                          int32_t min_token,
+                          int32_t max_token,
+                          int32_t max_order,
+                          Array1<float> *mean,
+                          Array1<float> *var,
+                          Array1<int32_t> *counts_out,
+                          Array1<int32_t> *ngram_order);
+
+}  // namespace k2
+#endif  // K2_CSRC_NBEST_H_
diff --git a/k2/csrc/nbest_test.cu b/k2/csrc/nbest_test.cu
new file mode 100644
index 000000000..8bc4600e3
--- /dev/null
+++ b/k2/csrc/nbest_test.cu
@@ -0,0 +1,444 @@
+/**
+ * Copyright      2021  Xiaomi Corporation (authors: Daniel Povey
+ *                                                   Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <iostream>
+#include <numeric>
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "k2/csrc/nbest.h"
+#include "k2/csrc/ragged.h"
+#include "k2/csrc/ragged_ops.h"
+
+
+namespace k2 {
+TEST(AlgorithmsTest, TestSuffixArray) {
+  ContextPtr cpu = GetCpuContext();
+
+  for (int i = 0; i < 100; i++) {
+    int array_len = RandInt(1, 50),  // 1 is the minimum, due to the
+                                     // termination symbol.
+        max_symbol = RandInt(10, 500);
+
+    Array1<int32_t> array(cpu, array_len + 3);
+    int32_t *array_data = array.Data();
+    for (int i = 0; i + 1 < array_len; i++)
+      array_data[i] = RandInt(1, max_symbol - 1);  // the termination symbol
+                                                   // must be larger than all
+                                                   // others, so don't allow
+                                                   // it here.
+    array_data[array_len - 1] = max_symbol;  // Termination symbol
+
+    for (int i = array_len; i < array_len + 3; i++)
+      array_data[i] = 0;
+
+    // size is really array_len; the extra element is to test that the
+    // algorithm doesn't write past the end.
+    Array1<int32_t> suffix_array(cpu, array_len + 1);
+    int32_t *suffix_array_data = suffix_array.Data();
+    suffix_array_data[array_len] = -10;  // should not be changed.
+    CreateSuffixArray(array_data, array_len,
+                      max_symbol, suffix_array_data);
+    assert(suffix_array_data[array_len] == -10);  // should be unchanged.
+    Array1<int32_t> seen_indexes(cpu, array_len, 0);
+    int32_t *seen_indexes_data = seen_indexes.Data();
+    for (int32_t i = 0; i < array_len; i++)
+      seen_indexes_data[suffix_array_data[i]] = 1;
+
+    for (int32_t i = 0; i < array_len; i++)
+      assert(seen_indexes_data[i] == 1);  // make sure all integers seen.
+    for (int32_t i = 0; i + 1 < array_len; i++) {
+      int32_t *suffix_a = array_data + suffix_array_data[i],
+              *suffix_b = array_data + suffix_array_data[i + 1];
+      // Check that each suffix is lexicographically less than the next one.
+      // None are identical, because the terminating zero is always in
+      // different positions.
+      while (true) {
+        if (*suffix_a < *suffix_b)
+          break;  // correct order
+        assert(!(*suffix_a > *suffix_b));  // order is wrong!
+        // should not get past the array end without the correct comparison
+        // order having been established.
+        assert(!(suffix_a > array_data + array_len ||
+                 suffix_b > array_data + array_len));
+        suffix_a++;
+        suffix_b++;
+      }
+    }
+  }
+}
+
+TEST(AlgorithmsTest, TestCreateLcpArray) {
+  ContextPtr cpu = GetCpuContext();
+
+  for (int i = 0; i < 100; i++) {
+    int array_len = RandInt(1, 50),  // at least 1, due to the termination
+                                     // symbol
+        max_symbol = RandInt(2, 5);
+
+    Array1<int32_t> array(cpu, array_len + 3);
+    int32_t *array_data = array.Data();
+    for (int i = 0; i + 1 < array_len; i++)
+      array_data[i] = RandInt(1, max_symbol - 1);
+    array_data[array_len - 1] = max_symbol;  // Termination symbol
+    for (int i = array_len; i < array_len + 3; i++)
+      array_data[i] = 0;
+
+    Array1<int32_t> suffix_array(cpu, array_len);
+    int32_t *suffix_array_data = suffix_array.Data();
+    CreateSuffixArray(array_data, array_len,
+                      max_symbol, suffix_array_data);
+
+    Array1<int32_t> lcp(cpu, array_len);
+    int32_t *lcp_data = lcp.Data();
+    CreateLcpArray(array_data, suffix_array_data, array_len,
+                   lcp_data);
+    if (array_len > 0)
+      assert(lcp_data[0] == 0);
+    for (int32_t i = 1; i < array_len; i++) {
+      int32_t lcp = lcp_data[i],
+              prev_pos = suffix_array_data[i - 1],
+              this_pos = suffix_array_data[i];
+      for (int32_t j = 0; j < lcp; j++)
+        assert(array_data[prev_pos + j] == array_data[this_pos + j]);
+      assert(array_data[prev_pos + lcp] != array_data[this_pos + lcp]);
+    }
+  }
+}
+
+TEST(AlgorithmsTest, TestCreateLcpIntervalArray) {
+  ContextPtr cpu = GetCpuContext();
+
+  for (int i = 0; i < 100; i++) {
+    int array_len = RandInt(1, 50),  // at least 1, due to the termination
+                                     // symbol
+        max_symbol = RandInt(3, 5);
+
+    Array1<int32_t> array(cpu, array_len + 3);
+    int32_t *array_data = array.Data();
+    for (int i = 0; i + 1 < array_len; i++)
+      array_data[i] = RandInt(1, max_symbol - 1);
+    array_data[array_len - 1] = max_symbol;  // Termination symbol
+    for (int i = array_len; i < array_len + 3; i++)
+      array_data[i] = 0;
+
+    Array1<int32_t> suffix_array(cpu, array_len);
+    int32_t *suffix_array_data = suffix_array.Data();
+    CreateSuffixArray(array_data, array_len,
+                      max_symbol, suffix_array_data);
+
+    Array1<int32_t> lcp(cpu, array_len);
+    int32_t *lcp_data = lcp.Data();
+    CreateLcpArray(array_data, suffix_array_data, array_len,
+                   lcp_data);
+
+    Array1<LcpInterval<int32_t> > lcp_intervals;
+    Array1<int32_t> leaf_parent_intervals;
+
+    CreateLcpIntervalArray(GetCpuContext(),
+                           array_len, lcp_data,
+                           &lcp_intervals,
+                           &leaf_parent_intervals);
+
+    LcpInterval<int32_t> *lcp_intervals_data = lcp_intervals.Data();
+    int32_t *leaf_parent_intervals_data = leaf_parent_intervals.Data();
+    int32_t num_intervals = lcp_intervals.Dim();
+    for (int32_t i = 0; i < array_len; i++) {
+      int32_t lcp_interval = leaf_parent_intervals_data[i];
+      assert(lcp_interval >= 0 && lcp_interval < num_intervals);
+      assert(i >= lcp_intervals_data[lcp_interval].lb &&
+             i <= lcp_intervals_data[lcp_interval].rb);
+      // the lcp value / height
+      int32_t lcp = lcp_intervals_data[lcp_interval].lcp;
+
+      for (int32_t j = 0; j < num_intervals; j++) {
+        // The interval that i is a member of should be the tightest
+        // enclosing interval; this loop checks that.
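+        // (Any distinct interval whose lcp-value is >= this one's and that
+        // contained i would be an equally tight or tighter enclosure of i,
+        // contradicting tightness.)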
+        if (lcp_intervals_data[j].lcp >= lcp && j != lcp_interval) {
+          assert(!(i >= lcp_intervals_data[j].lb &&
+                   i <= lcp_intervals_data[j].rb));
+        }
+      }
+    }
+
+    for (int32_t i = 0; i < num_intervals; i++) {
+      LcpInterval<int32_t> interval = lcp_intervals_data[i];
+      if (!(interval.lb == 0 && interval.rb + 1 == array_len &&
+            interval.parent == -1)) {
+        assert(interval.parent > i);
+        LcpInterval<int32_t> parent = lcp_intervals_data[interval.parent];
+        assert(interval.lb >= parent.lb &&
+               interval.rb <= parent.rb &&
+               interval.lcp > parent.lcp);
+      }
+      // Now check the basic requirements/definition of an lcp-interval...
+      assert(interval.lb >= 0 &&
+             (interval.rb > interval.lb || array_len == 1) &&
+             interval.rb < array_len);
+      assert(lcp_data[interval.lb] < interval.lcp ||
+             (interval.lb == 0 && interval.lcp == 0));
+      assert(interval.rb == array_len - 1 ||
+             lcp_data[interval.rb + 1] < interval.lcp);
+      if (array_len != 1) {
+        int32_t min_lcp = 1000000;
+        for (int32_t j = interval.lb + 1; j <= interval.rb; ++j)
+          if (lcp_data[j] < min_lcp)
+            min_lcp = lcp_data[j];
+        assert(min_lcp == interval.lcp);  // Check the lcp value is correct.
+                                          // This test does not work if
+                                          // array_len == 1, so we skip it
+                                          // in that case.
+      }
+    }
+  }
+}
+
+TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) {
+  ContextPtr cpu = GetCpuContext();
+
+  for (int i = 0; i < 100; i++) {
+    int array_len = RandInt(1, 50),  // at least 1, due to the termination
+                                     // symbol
+        max_symbol = RandInt(3, 5);
+
+    Array1<int32_t> array(cpu, array_len + 3),
+        counts(cpu, array_len);
+    int32_t *array_data = array.Data();
+    for (int i = 0; i + 1 < array_len; i++)
+      array_data[i] = RandInt(1, max_symbol - 1);
+    array_data[array_len - 1] = max_symbol;  // Termination symbol
+    for (int i = array_len; i < array_len + 3; i++)
+      array_data[i] = 0;
+
+    int32_t *counts_data = counts.Data();
+    for (int i = 0; i < array_len; i++)
+      counts_data[i] = RandInt(0, 1);
+
+    Array1<int32_t> suffix_array_plusone(cpu, array_len + 1, 0),
+        suffix_array = suffix_array_plusone.Range(0, array_len);
+    int32_t *suffix_array_data = suffix_array.Data();
+    CreateSuffixArray(array_data, array_len,
+                      max_symbol, suffix_array_data);
+
+    Array1<int32_t> lcp(cpu, array_len);
+    int32_t *lcp_data = lcp.Data();
+    CreateLcpArray(array_data, suffix_array_data, array_len,
+                   lcp_data);
+
+    Array1<LcpInterval<int32_t> > lcp_intervals;
+    Array1<int32_t> leaf_parent_intervals;  // dim will be seq_len
+
+    CreateLcpIntervalArray(GetCpuContext(),
+                           array_len, lcp_data,
+                           &lcp_intervals,
+                           &leaf_parent_intervals);
+    // we get one extra don't-care element at the end of `counts_reordered`,
+    // which is required by ExclusiveSum().
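+    // (suffix_array_plusone was created with one extra zero at the end, so
+    // the reordered array's last element is counts[0]; its value does not
+    // matter, since ExclusiveSum() never reads the last input element.)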
+    Array1<int32_t> counts_reordered = counts[suffix_array_plusone],
+        counts_reordered_sum(cpu, array_len + 1);
+    ExclusiveSum(counts_reordered, &counts_reordered_sum);
+
+
+    Array1<int32_t> leaf_parent_intervals_mod(leaf_parent_intervals.Clone());
+
+    FindTightestNonemptyIntervals(array_len,
+                                  &lcp_intervals,
+                                  &counts_reordered_sum,
+                                  &leaf_parent_intervals_mod);
+
+    LcpInterval<int32_t> *lcp_intervals_data = lcp_intervals.Data();
+    int32_t *leaf_parent_intervals_data = leaf_parent_intervals.Data(),
+        *leaf_parent_intervals_mod_data = leaf_parent_intervals_mod.Data();
+
+    int32_t num_intervals = lcp_intervals.Dim();
+    for (int32_t i = 0; i < array_len; i++) {
+      int32_t lcp_interval = leaf_parent_intervals_data[i],
+          nonempty_lcp_interval = leaf_parent_intervals_mod_data[i];
+      assert(lcp_interval >= 0 && lcp_interval < num_intervals);
+      assert(nonempty_lcp_interval >= 0 &&
+             nonempty_lcp_interval < num_intervals);
+      if (counts_reordered_sum[array_len] == 0) {
+        // If the total count is zero, everything should go to the top of the
+        // tree, but we won't otherwise test this.
+        assert(nonempty_lcp_interval == num_intervals - 1);
+      } else {
+        int32_t lcp = lcp_intervals_data[lcp_interval].lcp;
+        K2_CHECK_EQ((lcp_interval == nonempty_lcp_interval),
+                    (counts_reordered_sum[lcp_intervals_data[lcp_interval].lb] !=  // NOLINT
+                     counts_reordered_sum[lcp_intervals_data[lcp_interval].rb + 1]));  // NOLINT
+        K2_CHECK(i >= lcp_intervals_data[nonempty_lcp_interval].lb &&
+                 i <= lcp_intervals_data[nonempty_lcp_interval].rb);
+
+        for (int32_t j = 0; j < num_intervals; j++) {
+          // nonempty_lcp_interval should be the tightest enclosing
+          // interval that has nonzero count; this loop checks that.
+          if (lcp_intervals_data[j].lcp >= lcp && j != nonempty_lcp_interval) {
+            // Check that j is not a tighter enclosing interval than
+            // nonempty_lcp_interval, with nonzero count, that encloses i.
+            K2_CHECK(!(i >= lcp_intervals_data[j].lb &&
+                       i <= lcp_intervals_data[j].rb &&
+                       counts_reordered_sum[lcp_intervals_data[j].lb] !=
+                       counts_reordered_sum[lcp_intervals_data[j].rb + 1]));
+          }
+        }
+      }
+    }
+  }
+}
+
+TEST(AlgorithmTest, TestGetBestMatchingStatsEmpty) {
+  Ragged<int32_t> tokens(GetCpuContext(), "[ [ [ ] ] ]");
+  Array1<float> scores(GetCpuContext(), "[ ]");
+  Array1<int32_t> counts(GetCpuContext(), "[ ]");
+  Array1<float> mean, var;
+  Array1<int32_t> counts_out, ngram_order;
+  int32_t eos = 8,
+          min_token = 1,
+          max_token = 8,
+          max_order = 2;
+  GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token,
+                       max_order, &mean, &var, &counts_out, &ngram_order);
+
+  K2_CHECK_EQ(mean.Dim(), 0);
+  K2_CHECK_EQ(var.Dim(), 0);
+  K2_CHECK_EQ(counts_out.Dim(), 0);
+  K2_CHECK_EQ(ngram_order.Dim(), 0);
+}
+
+TEST(AlgorithmTest, TestGetBestMatchingStatsSingle) {
+  // There are 20 tokens, indexed with [0, 20).
+  // The keys' positions are [0, 10); the queries' positions are [10, 20).
+  // The best matching positions (including the token itself) are as follows
+  // (the lcp strings are written in text order, ending at the token; note
+  // that we add an eos(8) before the first sentence):
+  // index 0  : (0, 5, 10,) with lcp "84"
+  // index 1  : (1, 16,) with lcp "6"
+  // index 2  : (2, 17,) with lcp "67"
+  // index 3  : (3, 18,) with lcp "671"
+  // index 4  : (4, 19,) with lcp "6718"
+  // index 5  : (5, 10,) with lcp "7184"
+  // index 6  : (6, 11,) with lcp "71843"
+  // index 7  : (2, 7, 17,) with lcp "7"
+  // index 8  : (3, 8, 18,) with lcp "71"
+  // index 9  : (4, 9, 19,) with lcp "718"
+  // index 10 : (5, 10,) with lcp "7184"
+  // index 11 : (6, 11,) with lcp "71843"
+  // index 12 : (12,) with no matching
+  // index 13 : (3, 8, 13, 18,) with lcp "1"
+  // index 14 : (4, 9, 14, 19,) with lcp "18"
+  // index 15 : (15,) with no matching
+  // index 16 : (1, 16,) with lcp "6"
+  // index 17 : (2, 17,) with lcp "67"
+  // index 18 : (3, 18,) with lcp "671"
+  // index 19 : (4, 19,) with lcp "6718"
+  Ragged<int32_t> tokens(GetCpuContext(), "[ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] "
+                                          "  [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ]");
+  Array1<float> scores(GetCpuContext(), "[ 1 2 3 4 5 6 7 8 9 10 "
+                                        "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<int32_t> counts(GetCpuContext(), "[ 1 1 1 1 1 1 1 1 1 1 "
+                                          "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<float> mean, var;
+  Array1<int32_t> counts_out, ngram_order;
+  int32_t eos = 8,
+          min_token = 1,
+          max_token = 8,
+          max_order = 2;
+  GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token,
+                       max_order, &mean, &var, &counts_out, &ngram_order);
+  Array1<float> mean_ref(GetCpuContext(), "[ 3.5 2 3 4 5 6 7 5.5 6.5 7.5 "
+                                          "  6 7 5.5 6.5 7.5 5.5 2 3 4 5 ]");
+  Array1<float> var_ref(GetCpuContext(), "[ 6.25 0 0 0 0 0 0 6.25 6.25 6.25 "
+                                         "  0 0 8.25 6.25 6.25 8.25 0 0 0 0 ]");
+  Array1<int32_t> counts_out_ref(GetCpuContext(), "[ 2 1 1 1 1 1 1 2 2 2 "
+                                                  "  1 1 0 2 2 0 1 1 1 1 ]");
+  Array1<int32_t> ngram_order_ref(GetCpuContext(), "[ 2 1 2 2 2 2 2 1 2 2 "
+                                                   "  2 2 0 1 2 0 1 2 2 2 ]");
+  K2_CHECK(Equal(mean, mean_ref));
+  K2_CHECK(Equal(var, var_ref));
+  K2_CHECK(Equal(counts_out, counts_out_ref));
+  K2_CHECK(Equal(ngram_order, ngram_order_ref));
+}
+
+TEST(AlgorithmTest, TestGetBestMatchingStatsSpecial) {
+  Ragged<int32_t> tokens(GetCpuContext(), "[ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] "
+                                          "  [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ]");
+  Array1<float> scores(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 "
+                                        "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<int32_t> counts(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 "
+                                          "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<float> mean, var;
+  Array1<int32_t> counts_out, ngram_order;
+  int32_t eos = 8,
+          min_token = 1,
+          max_token = 8,
+          max_order = 2;
+  GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token,
+                       max_order, &mean, &var, &counts_out, &ngram_order);
+  Array1<float> mean_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 "
+                                          "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<float> var_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 "
+                                         "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<int32_t> counts_out_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 "
+                                                  "  0 0 0 0 0 0 0 0 0 0 ]");
+  Array1<int32_t> ngram_order_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 "
+                                                   "  0 0 0 0 0 0 0 0 0 0 ]");
+  K2_CHECK(Equal(mean, mean_ref));
+  K2_CHECK(Equal(var, var_ref));
+  K2_CHECK(Equal(counts_out, counts_out_ref));
+  K2_CHECK(Equal(ngram_order, ngram_order_ref));
+}
+
+TEST(AlgorithmTest, TestGetBestMatchingStatsSingleMulti) {
+  Ragged<int32_t> tokens(GetCpuContext(), "[ [ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] "
+                                          "    [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ] "
+                                          "  [ [ 5 1 4 8 ] [ 5 1 2 8 ] "
+                                          "    [ 5 3 4 8 ] ] ]");
+  Array1<float> scores(GetCpuContext(), "[ 1 2 3 4 5 6 7 8 9 10 "
+                                        "  0 0 0 0 0 0 0 0 0 0 "
+                                        "  1 2 3 4 5 7 8 6 0 0 0 0 ]");
+  Array1<int32_t> counts(GetCpuContext(), "[ 1 1 1 1 1 1 1 1 1 1 "
+                                          "  0 0 0 0 0 0 0 0 0 0 "
+                                          "  1 1 1 1 1 1 1 1 0 0 0 0 ]");
+  Array1<float> mean, var;
+  Array1<int32_t> counts_out, ngram_order;
+  int32_t eos = 8,
+          min_token = 0,
+          max_token = 10,
+          max_order = 5;
+  GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token,
+                       max_order, &mean, &var, &counts_out, &ngram_order);
+  Array1<float> mean_ref(GetCpuContext(),
+                         "[ 3.5 2 3 4 5 6 7 5.5 6.5 7.5 "
+                         "  6 7 5.5 6.5 7.5 5.5 2 3 4 5 "
+                         "  3 4.5 3 4 3 4.5 4.5 5 "
+                         "  3 4.5 3 4 ]");
+  Array1<float> var_ref(GetCpuContext(),
+                        "[ 6.25 0 0 0 0 0 0 6.25 6.25 6.25 "
+                        "  0 0 8.25 6.25 6.25 8.25 0 0 0 0 "
+                        "  4 6.25 0 0 4 6.25 5.25 1 "
+                        "  4 5.25 0 0 ]");
+  Array1<int32_t> counts_out_ref(GetCpuContext(),
+                                 "[ 2 1 1 1 1 1 1 2 2 2 "
+                                 "  1 1 0 2 2 0 1 1 1 1 "
+                                 "  2 2 1 1 2 2 0 2 "
+                                 "  2 0 1 1 ]");
+  Array1<int32_t> ngram_order_ref(GetCpuContext(),
+                                  "[ 5 1 2 3 4 5 5 1 2 3 "
+                                  "  5 5 0 1 2 0 1 2 3 4 "
+                                  "  5 5 1 2 5 5 0 1 "
+                                  "  5 0 1 2 ]");
+  K2_CHECK(Equal(mean, mean_ref));
+  K2_CHECK(Equal(var, var_ref));
+  K2_CHECK(Equal(counts_out, counts_out_ref));
+  K2_CHECK(Equal(ngram_order, ngram_order_ref));
+}
+
+}  // namespace k2
diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu
index 0e6851488..46f0aa819 100644
--- a/k2/python/csrc/torch.cu
+++ b/k2/python/csrc/torch.cu
@@ -30,6 +30,7 @@
 #include "k2/python/csrc/torch/fsa_algo.h"
 #include "k2/python/csrc/torch/index_add.h"
 #include "k2/python/csrc/torch/index_select.h"
+#include "k2/python/csrc/torch/nbest.h"
 #include "k2/python/csrc/torch/ragged.h"
 #include "k2/python/csrc/torch/ragged_ops.h"
 
@@ -40,6 +41,7 @@ void PybindTorch(py::module &m) {
   PybindFsaAlgo(m);
   PybindIndexAdd(m);
   PybindIndexSelect(m);
+  PybindNbest(m);
   PybindRagged(m);
   PybindRaggedOps(m);
 }
diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt
index 80cbe5533..9c6a82569 100644
--- a/k2/python/csrc/torch/CMakeLists.txt
+++ b/k2/python/csrc/torch/CMakeLists.txt
@@ -6,6 +6,7 @@ set(torch_srcs
   fsa_algo.cu
   index_add.cu
   index_select.cu
+  nbest.cu
   ragged.cu
   ragged_ops.cu
   torch_util.cu
diff --git a/k2/python/csrc/torch/discounted_cum_sum.cu b/k2/python/csrc/torch/discounted_cum_sum.cu
index 2aba8135d..51cc9d8bb 100644
--- a/k2/python/csrc/torch/discounted_cum_sum.cu
+++ b/k2/python/csrc/torch/discounted_cum_sum.cu
@@ -2,7 +2,7 @@
  * @brief wraps discounted_cum_sum code.
  *
  * @copyright
- * Copyright      2010  Xiaomi Corp.  (authors: Daniel Povey)
+ * Copyright      2021  Xiaomi Corp.  (authors: Daniel Povey)
  *
  * @copyright
  * See LICENSE for clarification regarding multiple authors
diff --git a/k2/python/csrc/torch/nbest.cu b/k2/python/csrc/torch/nbest.cu
new file mode 100644
index 000000000..ad2b7f094
--- /dev/null
+++ b/k2/python/csrc/torch/nbest.cu
@@ -0,0 +1,62 @@
+/**
+ * @brief wraps nbest code.
+ *
+ * @copyright
+ * Copyright 2021 Xiaomi Corp. (authors: Wei Kang)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <tuple>
+
+#include "k2/csrc/context.h"
+#include "k2/csrc/device_guard.h"
+#include "k2/csrc/macros.h"
+#include "k2/csrc/nbest.h"
+#include "k2/csrc/nvtx.h"
+#include "k2/csrc/tensor_ops.h"
+#include "k2/python/csrc/torch/nbest.h"
+#include "k2/python/csrc/torch/torch_util.h"
+
+namespace k2 {
+
+static void PybindGetBestMatchingStats(py::module &m) {
+  m.def(
+      "get_best_matching_stats",
+      [](Ragged<int32_t> &tokens, torch::Tensor scores, torch::Tensor counts,
+         int32_t eos, int32_t min_token, int32_t max_token,
+         int32_t max_order) -> std::tuple<torch::Tensor, torch::Tensor,
+                                          torch::Tensor, torch::Tensor> {
+        DeviceGuard guard(tokens.Context());
+        Array1<float> scores_array = FromTorch<float>(scores);
+        Array1<int32_t> counts_array = FromTorch<int32_t>(counts);
+        Array1<float> mean, var;
+        Array1<int32_t> counts_out, ngram_order;
+        GetBestMatchingStats(tokens, scores_array, counts_array,
+                             eos, min_token, max_token, max_order,
+                             &mean, &var, &counts_out, &ngram_order);
+        return std::make_tuple(ToTorch(mean), ToTorch(var),
+                               ToTorch(counts_out), ToTorch(ngram_order));
+      },
+      py::arg("tokens"), py::arg("scores"), py::arg("counts"), py::arg("eos"),
+      py::arg("min_token"), py::arg("max_token"), py::arg("max_order"));
+}
+
+}  // namespace k2
+
+void PybindNbest(py::module &m) {
+  k2::PybindGetBestMatchingStats(m);
+}
diff --git a/k2/python/csrc/torch/nbest.h b/k2/python/csrc/torch/nbest.h
new file mode 100644
index 000000000..a6d5e805e
--- /dev/null
+++ b/k2/python/csrc/torch/nbest.h
@@ -0,0 +1,30 @@
+/**
+ * @brief python wrapper for nbest.h
+ *
+ * @copyright
+ * Copyright 2021 Xiaomi Corp. (authors: Wei Kang)
+ *
+ * @copyright
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_PYTHON_CSRC_TORCH_NBEST_H_
+#define K2_PYTHON_CSRC_TORCH_NBEST_H_
+
+#include "k2/python/csrc/k2.h"
+
+void PybindNbest(py::module &m);
+
+#endif  // K2_PYTHON_CSRC_TORCH_NBEST_H_
diff --git a/k2/python/csrc/torch/ragged_ops.cu b/k2/python/csrc/torch/ragged_ops.cu
index 61a18c2e1..45496fba4 100644
--- a/k2/python/csrc/torch/ragged_ops.cu
+++ b/k2/python/csrc/torch/ragged_ops.cu
@@ -411,6 +411,7 @@ void PybindRaggedOps(py::module &m) {
   PybindArgMaxPerSublist(m);
   PybindArgMaxPerSublist(m);
   PybindCat(m);
+  PybindCat(m);
   PybindCat(m);
   PybindCreateRagged2(m);
   PybindCreateRagged2(m);
diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py
index 2e17fa7ab..aee16baa9 100644
--- a/k2/python/k2/__init__.py
+++ b/k2/python/k2/__init__.py
@@ -52,12 +52,13 @@
 from .ragged import RaggedShape
 from .symbol_table import SymbolTable
 from .utils import create_fsa_vec
+from .utils import create_sparse
 from .utils import is_rand_equivalent
+from .utils import get_best_matching_stats
 from .utils import to_dot
 from .utils import to_str
 from .utils import to_str_simple
 from .utils import to_tensor
-from .utils import create_sparse
 from .utils import random_fsa
 from .utils import random_fsa_vec
 from _k2.version import with_cuda
diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py
index 962476ef4..7b5bb9a24 100644
--- a/k2/python/k2/utils.py
+++ b/k2/python/k2/utils.py
@@ -656,3 +656,93 @@
     random_arcs = _k2.random_fsa_vec(min_num_fsas, max_num_fsas, acyclic,
                                      max_symbol, min_num_arcs, max_num_arcs)
     return Fsa(random_arcs)
+
+
+def get_best_matching_stats(tokens: _k2.RaggedInt, scores: torch.Tensor,
+                            counts: torch.Tensor, eos: int, min_token: int,
+                            max_token: int, max_order: int
+                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:  # noqa
+    '''For "query" sentences, this function gets the mean and variance of
+    scores from the best matching words-in-context in a set of provided
+    "key" sentences. The matching process matches each word together with
+    the words preceding it, looking for the highest-order match it can find
+    (it is intended for approximating the scores of models that see only
+    left-context, such as language models). The intended application is
+    estimating the scores of hypothesized transcripts when we have actually
+    computed the scores for only a subset of the hypotheses.
+
+    CAUTION:
+      This function only runs on CPU for now.
+
+    Args:
+      tokens:
+        A ragged tensor of int32_t with 2 or 3 axes. If 2 axes, this
+        represents a collection of key and query sequences. If 3 axes, this
+        represents a set of such collections.
+
+        2-axis example:
+          [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ]
+        3-axis example:
+          [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ],
+            [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ]
+
+        where the words would actually be represented as integers. The eos
+        symbol is required if this code is to work as intended (otherwise
+        the code cannot recognize when it has reached the beginnings of
+        sentences while comparing histories). bos symbols are allowed but
+        not required.
+      scores:
+        A one-dimensional torch.Tensor with scores.size() ==
+        tokens.NumElements(); this holds the values for which we are
+        requesting best-matching statistics (as means and variances, in
+        case there are multiple best matches). In our anticipated use these
+        would represent scores of words in the sentences, but they could
+        represent anything.
+      counts:
+        A one-dimensional torch.Tensor with counts.size() ==
+        tokens.NumElements(), containing 1 for words that are considered
+        "keys" and 0 for words that are considered "queries". Typically
+        some entire sentences will be keys and others will be queries.
+      eos:
+        The value of the eos (end of sentence) symbol; internally, it is
+        used as an extra padding value before the first sentence in each
+        collection, so that it can act like a "bos" symbol.
+      min_token:
+        The lowest possible token value, including the bos symbol
+        (e.g., might be -1).
+      max_token:
+        The maximum possible token value. Be careful not to set this too
+        large; the implementation contains a part which takes time and
+        space O(max_token - min_token).
+      max_order:
+        The maximum n-gram order to ever return in the `ngram_order`
+        output; the output will be the minimum of max_order and the actual
+        order matched, or max_order if we matched all the way to the
+        beginning of both sentences. The main reason this is needed is
+        that we need a finite number to return at the beginning of
+        sentences.
+
+    Returns:
+      Returns a tuple of four torch.Tensors (mean, var, counts_out,
+      ngram_order):
+      mean:
+        For query positions, contains the mean of the scores at the best
+        matching key positions, or zero if that is undefined because there
+        are no key positions at all. For key positions, you can treat the
+        output as being undefined (they are actually treated the same as
+        queries, but will not match with only themselves because we do not
+        match at singleton intervals).
+      var:
+        Like `mean`, but contains the (centered) variance of the best
+        matching positions.
+      counts_out:
+        The number of key positions that contributed to the `mean` and
+        `var` statistics. This should only be zero if `counts` was all
+        zero.
+      ngram_order:
+        The n-gram order corresponding to the best matching positions
+        found at each query position, up to a maximum of `max_order`; will
+        be `max_order` if we matched all the way to the beginning of a
+        sentence.
+    '''
+    return _k2.get_best_matching_stats(tokens, scores, counts, eos,
+                                       min_token, max_token, max_order)
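A minimal usage sketch of this wrapper (it mirrors the new Python test added below; the data is made up, everything else is the API defined above):

import k2
import torch

tokens = k2.RaggedInt('[ [ [ 5 1 4 6 ] [ 5 1 2 6 ] [ 5 3 4 6 ] ] ]')
scores = torch.tensor([1, 2, 3, 4, 5, 7, 8, 6, 0, 0, 0, 0],
                      dtype=torch.float32)
# 1 marks key positions (the first two sentences), 0 marks queries:
counts = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
                      dtype=torch.int32)
mean, var, counts_out, ngram_order = k2.get_best_matching_stats(
    tokens, scores, counts, eos=6, min_token=1, max_token=6, max_order=2)
# Query positions are those with counts == 0; e.g. their means:
print(mean[counts == 0])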
diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt
index 5d32a5e3e..08f96de1a 100644
--- a/k2/python/tests/CMakeLists.txt
+++ b/k2/python/tests/CMakeLists.txt
@@ -36,6 +36,7 @@ set(py_test_files
   get_arc_post_test.py
   get_backward_scores_test.py
   get_forward_scores_test.py
+  get_best_matching_stats_test.py
   get_tot_scores_test.py
   index_add_test.py
   index_and_sum_test.py
diff --git a/k2/python/tests/get_best_matching_stats_test.py b/k2/python/tests/get_best_matching_stats_test.py
new file mode 100644
index 000000000..9430c577e
--- /dev/null
+++ b/k2/python/tests/get_best_matching_stats_test.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (author: Wei Kang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# To run this single test, use
+#
+#  ctest --verbose -R get_best_matching_stats_test_py
+
+import unittest
+
+import k2
+import torch
+
+
+class TestGetBestMatchingStats(unittest.TestCase):
+
+    def test(self):
+        s = '[ [ [ 5 1 4 6 ] [ 5 1 2 6 ] [ 5 3 4 6 ] ] ]'
+        tokens = k2.RaggedInt(s)
+        scores = torch.tensor([1, 2, 3, 4, 5, 7, 8, 6, 0, 0, 0, 0],
+                              dtype=torch.float32)
+        counts = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+                              dtype=torch.int32)
+        eos = 6
+        min_token = 1
+        max_token = 6
+        max_order = 2
+        mean, var, counts_out, ngram_order = k2.get_best_matching_stats(
+            tokens, scores, counts, eos, min_token, max_token, max_order)
+
+        mean_ref = torch.tensor([3, 4.5, 3, 4, 3, 4.5, 4.5, 5,
+                                 3, 4.5, 3, 4], dtype=torch.float32)
+        var_ref = torch.tensor([4, 6.25, 0, 0, 4, 6.25, 5.25, 1,
+                                4, 5.25, 0, 0], dtype=torch.float32)
+        counts_out_ref = torch.tensor([2, 2, 1, 1, 2, 2, 0, 2,
+                                       2, 0, 1, 1], dtype=torch.int32)
+        ngram_order_ref = torch.tensor([2, 2, 1, 2, 2, 2, 0, 1,
+                                        2, 0, 1, 2], dtype=torch.int32)
+        assert torch.allclose(mean, mean_ref)
+        assert torch.allclose(var, var_ref)
+        assert torch.all(torch.eq(counts_out, counts_out_ref))
+        assert torch.all(torch.eq(ngram_order, ngram_order_ref))
+
+
+if __name__ == '__main__':
+    unittest.main()
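Finally, a sketch of the 2-axis form from the docstring, closer to the intended application of scoring hypothesized transcripts. The vocabulary, sentences, and scores here are invented, and the comments about the expected outputs are inferences from the docstring, not guaranteed values:

import k2
import torch

# Toy vocabulary; the ids are arbitrary as long as they lie in
# [min_token, max_token] and eos is consistent.
sym = {'the': 1, 'cat': 2, 'said': 3, 'fed': 4, 'eos': 5}
# One scored key sentence ("the cat said eos") followed by one unscored
# query sentence ("the cat fed eos"), both eos-terminated:
tokens = k2.RaggedInt('[ [ 1 2 3 5 ] [ 1 2 4 5 ] ]')
scores = torch.tensor([0.5, 0.2, 0.9, 0.1, 0, 0, 0, 0], dtype=torch.float32)
counts = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.int32)
mean, var, counts_out, ngram_order = k2.get_best_matching_stats(
    tokens, scores, counts, eos=sym['eos'], min_token=1,
    max_token=max(sym.values()), max_order=3)
# "the cat" at the start of the query should match the key with a high
# n-gram order, so mean at those query positions should copy the key's
# scores there; "fed" has no match in the key, so its counts_out entry
# should be 0.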