From 53520a12a6cb5232b594e222c39d36070818668f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 13 Jul 2021 19:32:49 +0800 Subject: [PATCH 01/16] Some progress on suffix arrays --- k2/csrc/CMakeLists.txt | 5 +- k2/csrc/nbest.cu | 289 +++++++++++++++++++++++++++++++++++++++++ k2/csrc/nbest.h | 204 +++++++++++++++++++++++++++++ 3 files changed, 497 insertions(+), 1 deletion(-) create mode 100644 k2/csrc/nbest.cu create mode 100644 k2/csrc/nbest.h diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 221c2fbc9..db48d8b91 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -71,8 +71,10 @@ set(context_srcs timer.cu top_sort.cu utils.cu + nbest.cu ) + if(K2_USE_PYTORCH) list(APPEND context_srcs pytorch_context.cu) else() @@ -86,7 +88,7 @@ else() endif() # the target -add_library(context ${context_srcs}) +add_library(context ${context_srcs} ${context_cc_srcs}) target_compile_definitions(context PUBLIC K2_TORCH_VERSION_MAJOR=${K2_TORCH_VERSION_MAJOR}) target_compile_definitions(context PUBLIC K2_TORCH_VERSION_MINOR=${K2_TORCH_VERSION_MINOR}) @@ -148,6 +150,7 @@ set(cuda_test_srcs log_test.cu macros_test.cu math_test.cu + nbest_test.cu nvtx_test.cu pinned_context_test.cu ragged_shape_test.cu diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu new file mode 100644 index 000000000..028a898e8 --- /dev/null +++ b/k2/csrc/nbest.cu @@ -0,0 +1,289 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/nbest.h" + +// This is not really a CUDA file but for build-system reasons I'm currently leaving it +// with the .cu extension. + +namespace k2 { + +template +inline bool Leq(T a1, T a2, T b1, T b2) { + // lexicographic order for pairs, used in CreateSuffixArray() + return(a1 < b1 || a1 == b1 && a2 <= b2); +} +template +inline bool Leq(T a1, T a2, T a3, T b1, T b2, T b3) { + // lexicographic order for triples, used in CreateSuffixArray() + return(a1 < b1 || a1 == b1 && Leq(a2,a3, b2,b3)); +} + + +/* + Helper function for CreateSuffixArray(). + + Stably sorts a[0..n-1] to b[0..n-1] with keys in 0..K from r; + i.e. the values in a are interpreted as indexes into the array + `r` and the values in `r` are used for comparison, so that + at exit, r[b[i]] <= r[b[i+1]]. +*/ +template +static void RadixPass(const T* a, T* b, const T* r, T n, T K) { + T* c = new T[K + 1]; // counter array + for (T i = 0; i <= K; i++) c[i] = 0; // reset counters + for (T i = 0; i < n; i++) c[r[a[i]]]++; // count occurrences + for (T i = 0, sum = 0; i <= K; i++) {// exclusive prefix sums + T t = c[i]; c[i] = sum; sum += t; + } + for (T i = 0; i < n; i++) b[c[r[a[i]]]++] = a[i]; // sort + delete [] c; +} + + +// See documentation in nbest.h, where we use different names +// for the arguments (here, we leave the names the same as in +// https://algo2.iti.kit.edu/documents/jacm05-revised.pdf. +template +void CreateSuffixArray(const T* text, T n, T K, T* SA) { + if (n == 1) { // The paper's code didn't seem to handle n == 1 correctly. + SA[0] = 0; + return; + } + T n0=(n+2)/3, n1=(n+1)/3, n2=n/3, n02=n0+n2; + T* R = new T[n02 + 3]; R[n02]= R[n02+1]= R[n02+2]=0; + T* SA12 = new T[n02 + 3]; SA12[n02]=SA12[n02+1]=SA12[n02+2]=0; + T* R0 = new T[n0]; + T* SA0 = new T[n0]; + //******* Step 0: Construct sample ******** + // generate positions of mod 1 and mod 2 suffixes + // the "+(n0-n1)" adds a dummy mod 1 suffix if n%3 == 1 + for (T i=0, j=0; i < n+(n0-n1); i++) if (i%3 != 0) R[j++] = i; + //******* Step 1: Sort sample suffixes ******** + // lsb radix sort the mod 1 and mod 2 triples + RadixPass(R, SA12, text+2, n02, K); + RadixPass(SA12, R , text+1, n02, K); + RadixPass(R, SA12, text, n02, K); + + // find lexicographic names of triples and + // write them to correct places in R + T name = 0, c0 = -1, c1 = -1, c2 = -1; + for (T i = 0; i < n02; i++) { + if (text[SA12[i]] != c0 || text[SA12[i]+1] != c1 || text[SA12[i]+2] != c2) + { name++; c0 = text[SA12[i]]; c1 = text[SA12[i]+1]; c2 = text[SA12[i]+2]; } + if (SA12[i] % 3 == 1) { R[SA12[i]/3] = name; } // write to R1 + else { R[SA12[i]/3 + n0] = name; } // write to R2 + } + // recurse if names are not yet unique + if (name < n02) { + CreateSuffixArray(R, n02, name, SA12); + // store unique names in R using the suffix array + for (T i = 0; i < n02; i++) R[SA12[i]] = i + 1; + } else // generate the suffix array of R directly + for (T i = 0; i < n02; i++) SA12[R[i] - 1] = i; + //******* Step 2: Sort nonsample suffixes ******** + // stably sort the mod 0 suffixes from SA12 by their first character + for (T i=0, j=0; i < n02; i++) if (SA12[i] < n0) R0[j++] = 3*SA12[i]; + RadixPass(R0, SA0, text, n0, K); + //******* Step 3: Merge ******** + // merge sorted SA0 suffixes and sorted SA12 suffixes + for (T p=0, t=n0-n1, k=0; k < n; k++) { + // i is pos of current offset 12 suffix + T i = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2) ; + T j = SA0[p]; // pos of current offset 0 suffix + if (SA12[t] < n0 ? // different compares for mod 1 and mod 2 suffixes + Leq(text[i], R[SA12[t] + n0], text[j], R[j/3]) : + Leq(text[i],text[i+1],R[SA12[t]-n0+1], text[j],text[j+1],R[j/3+n0])) + { // suffix from SA12 is smaller + SA[k] = i; t++; + if (t == n02) // done --- only SA0 suffixes left + for (k++; p < n0; p++, k++) SA[k] = SA0[p]; + } else { // suffix from SA0 is smaller + SA[k] = j; p++; + if (p == n0) // done --- only SA12 suffixes left + for (k++; t < n02; t++, k++) + SA[k] = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2); + } + } + delete [] R; delete [] SA12; delete [] SA0; delete [] R0; +} + +// Instantiate template for int32_t and int16_t +template void CreateSuffixArray(const int32_t* text, int32_t n, + int32_t K, int32_t* SA); +template void CreateSuffixArray(const int16_t* text, int16_t n, + int16_t K, int16_t* SA); + + + +// This implements Kasai's algorithm, as summarized here +// https://people.csail.mit.edu/jshun/lcp.pdf +// (Note: there seem to be some wrong implementations of +// Kasai's algorithm online). +template +void CreateLcpArray(const T *array, + const T *suffix_array, + T seq_len, + T *lcp_array) { + Array1 inv_suffix_array(GetCpuContext(), seq_len); + T *inv_suffix_array_data = inv_suffix_array.Data(); + for (T i = 0; i < seq_len; i++) { + inv_suffix_array_data[suffix_array[i]] = i; + } + + T k = 0; + + if (seq_len > 0) + lcp_array[0] = 0; + + for (T i = 0; i < seq_len; ++i) { + T cur_rank = inv_suffix_array[i]; + if (cur_rank != 0) { + T j = suffix_array[cur_rank - 1]; + while (array[i + k] == array[j + k]) + ++k; + lcp_array[cur_rank] = k; + if (k > 0) + --k; + } + } +} + +// Instantiate template for int32_t and int16_t +template void CreateLcpArray(const int32_t *array, const int32_t *suffix_array, + int32_t seq_len, int32_t *lcp_array); +template void CreateLcpArray(const int16_t *array, const int16_t *suffix_array, + int16_t seq_len, int16_t *lcp_array); + + +template +void CreateLcpIntervalArray(ContextPtr c, + T seq_len, + T *lcp_array, + Array1 > *lcp_intervals, + Array1 *lcp_intervals_order, + Array1 *leaf_parent_intervals) { + + // + *lcp_intervals = Array1 >(c, seq_len); + LcpInterval *lcp_intervals_data = lcp_intervals->Data(); + + Array1 intervals_order(c, seq_len); + T *intervals_order_data = intervals_order.Data(); + + Array1 leaf_parent(c, seq_len); + T *leaf_parent_data = leaf_parent.Data(); + + + // This is the stack from Algorithm 1 and Algorithm 2 of + // http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf + // (you can refer to the papers mentioned in the documentation in nbest.h if this link goes + // dead). + // + // The 'begin', 'last' and 'lcp' members correspond to the 'lb', 'rb' and 'lcp' + // members mentioned there; the 'parent' member is used temporarily on the stack + // to refer to the index of this LcpInterval in `lcp_intervals_data`, i.e. + // it can be interpreted as a 'self' pointer. + std::vector > stack; + + // A separate stack, of leaves of suffix tree; we maintain this so that + // we can assign the `leaf_parent` array. + std::vector leaf_stack; + + // lcp=0; begin=0; last=undefined; self=0 (interpreting the 'parent' member + // as index-of-self + T next = 0; // Will always store the next free index into `lcp_intervals_data` + T dfs_next = 0; // Will always store the next free index into `intervals_order_data`; + // this is an ordering of the indexes into `lcp_intervals_data` + // that corresponds to depth-first search. + T last_interval = -1; // Will store an index into `lcp_intervals`; this comes + // from Algorithm 2 mentioned above + stack.push_back({0, 0, seq_len, next++ }); + lcp_intervals_data[0] = stack.back(); + // We are using zero-based indexing so the code is not quite the same as our reference. + for (T i = 0; i < seq_len; ++i) { + T lb = i, lcp_array_i = lcp_array[i]; + leaf_stack.push_back(lb); + + while (lcp_array_i < stack.back().lcp) { + stack.back().last = i - 1; + last_interval = stack.back().parent; // actually, the .parent field + // currently represents 'self', + // i.e. the index of the + // lcp-interval stack.back(). + lb = stack.back().begin; + while (!leaf_stack.empty() && leaf_stack.back() >= lb) { + leaf_parent_data[leaf_stack.back()] = last_interval; + leaf_stack.pop_back(); + } + + // process(last_interval): + lcp_intervals_data[last_interval] = stack.back(); + intervals_order_data[dfs_next++] = last_interval; + stack.pop_back(); + if (lcp_array_i <= stack.back().lcp) { + // lcp_intervals_data[last_interval].parent represents the parent + // of `last_interval`; `stack.back().parent` currently represents + // the intended position of stack.back() itself, not of its parent. + lcp_intervals_data[last_interval].parent = stack.back().parent; + last_interval = -1; + } + } + if (lcp_array_i > stack.back().lcp) { + if (last_interval >= 0) { + lcp_intervals_data[last_interval].parent = next; + last_interval = -1; + } + stack.push_back({lcp_array_i, lb, -1, next++}); + } + } + assert(stack.size() == 1); + intervals_order_data[dfs_next++] = 0; + while (!leaf_stack.empty()) { + leaf_parent_data[leaf_stack.back()] = 0; + leaf_stack.pop_back(); + } + assert(dfs_next == next); + + + *lcp_intervals = lcp_intervals->Range(0, next); + if (lcp_intervals_order) + *lcp_intervals_order = intervals_order.Range(0, next); + if (leaf_parent_intervals) + *leaf_parent_intervals = leaf_parent; + +} + +// Instantiate template +template +void CreateLcpIntervalArray(ContextPtr c, + int32_t seq_len, + int32_t *lcp_array, + Array1 > *lcp_intervals, + Array1 *lcp_intervals_order, + Array1 *leaf_parent_intervals); +template +void CreateLcpIntervalArray(ContextPtr c, + int16_t seq_len, + int16_t *lcp_array, + Array1 > *lcp_intervals, + Array1 *lcp_intervals_order, + Array1 *leaf_parent_intervals); + + +} // namespace k2 diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h new file mode 100644 index 000000000..9f6743e10 --- /dev/null +++ b/k2/csrc/nbest.h @@ -0,0 +1,204 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_CSRC_NBEST_H_ +#define K2_CSRC_NBEST_H_ + +#include +#include + +#include "k2/csrc/algorithms.h" +#include "k2/csrc/array.h" +#include "k2/csrc/log.h" +#include "k2/csrc/macros.h" +#include "k2/csrc/ragged.h" +#include "k2/csrc/utils.h" + +namespace k2 { + +// This header contains certain utility functions that are used in rescoring +// n-best lists: specifically, functions that help us select which among a set +// of n-best candidates to select for rescoring. The selection scheme is a little +// complex. It is intended to be used in a context where we do multiple successive +// rounds of n-best list rescoring, and we use the results of the 1st round +// to guide selection of candidates in the second round. So for each word +// in each n-best path that we are considering, we find the best-matching +// positions among those that we evaluated in the first round and we use those +// as inputs to a model that predicts the scores of words after n-best-list +// rescoring. +// +// Some of these functions may seem a little unrelated to n-best lists, they +// are algorithms involving suffix arrays, which we use internally in some +// algorithms we use to process n-best lists. + + + +/* + This function creates a suffix array; it is based on the + code in + https://algo2.iti.kit.edu/documents/jacm05-revised.pdf, + "Linear Work Suffix Array construction" by J. Karkkainen. + + Template args: T should be a signed integer type, we + plan to instantiate this for int32_t and int16_t only. + + @param [in] text_array Pointer to the input array of symbols + (all pointers must be CPU pointers only, for now), + whose suffixes are to be sorted. Logically this + has length `seq_len`, and symbols are required + to be in the range [1..max_symbol]. It is required + to be terminated by 3 zeros, i.e. + text_array[seq_len] == text_array[seq_len+1] == text_array[seq_len+2] == 0 + @param [in] seq_len Length of the symbol sequence (`text_array` + must be longer than this by at least 3, for termination.) + Require seq_len >= 0 + @param [out] suffix_array A pre-allocated array of length + `seq_len + 1`. At exit it will contain a permutation of + the list [ 0, 1, ... seq_len ], interpreted + as the start indexes of suffixes of `text_array`, + with the property that the sub-arrays of `text_array` + starting at these positions are lexicographically sorted. + For example, as a trivial case, if seq_len = 3 + and text_array contains [ 3, 2, 1, 0, 0, 0 ], then + `suffix_array` would contain [ 2, 1, 0 ] at exit. + @param [in] max_symbol A number that must be >= the largest + number that might be in `text_array`. The work done + is O(seq_len + max_symbol), so it is not advisable + to let max_symbol be too large. + Caution: this function allocates memory internally (although + not much more than `text_array` itself). + */ +template +void CreateSuffixArray(const T *text_array, + T seq_len, + T max_symbol, + T *suffix_array); + +/* + Create the LCP array, which is the array of lengths of longest common prefixes + (of successive pairs of suffixes). + + Template args: T should be a signed integer type, we + plan to instantiate this for int32_t and int16_t only. + + @param [in] text_array The array of symbols, of length `seq_len` plus at least + one terminating zero. The symbols should be positive + (this may not be required here, but it is rqeuired by + CreateSuffixArray()). + @param [in] suffix_array The suffix array, as created by CreateSuffixArray(); + it is a permutation of the numbers 0, 1, ... seq_len - 1. + @param [out] lcp_array An array of length `seq_len` is output to here; + it is expected to be pre-allocated. At exit, lcp_array[0] + will be 0 and lcp_array[i] for i>0 will equal the length + of the longest common prefix of + (text_array+suffix_array[i-1]) and (text_array+suffix_array[i]). +*/ +template +void CreateLcpArray(const T *text_array, + const T *suffix_array, + T seq_len, + T *lcp_array); + +/* + lcp-intervals correspond to the nodes in the suffix trie; they are a concept + used with suffix arrays, and are derived from the LCP table (see lcp_array + output of CreateLcpArray). Take care with the notation here: intervals are + "closed intervals" so [i,j] means i,i+1,...,j, i.e. the RHS is the index of + the last element, not one past the last element. + + Notation: [i,j] is an lcp-interval with lcp-value l, if: + 0 <= i < j < seq_len + lcptab[i] < l + lcptab[j+1] < l + l is the minimum of (lcptab[i], lcptab[i+1], ..., lcptab[j]) + lcp-intervals correspond to the internal nodes of the suffix trie, so + they always contain at least two children (where children can be + leaves, corresponding indexes into the suffix array, or other + lcp-intervals). + + Type LcpInterval is used to store information about the lcp interval, + which we'll later use in algorithms that traverse the suffix tree. + + + */ +template +struct LcpInterval { + // Represents the lcp-interval [begin,last] with lcp-value `lcp` + T lcp; // The lcp-value of the lcp-interval, which is the length of the + // longest prefix shared by all suffixes in this interval. + T begin; // Index of the first element + T last; // Index of the last element; we don't call this 'end' because that + // is generally used to mean one past the end. + T parent; // The parent of this lcp-interval (-1 if this is the top interval), + // in the order in which it appears in this array (of + // lcp-intervals). Note: this order is neither top-down or + // bottom-up; you can treat it as arbitrary. +}; + + +/* + Create an array of struct LcpInterval which describes the Lcp intervals + corresponding to the internal nodes of the suffix tree, and allows you + to easily run algorithms on this tree. This data structure is not + very memory-optimized and doesn't correspond to anything in the literature, + although the basic tree traversal algorithm comes from + [Mohamed Ibrahim Abouelhoda, Stefan Kurtz, Enno Ohlebusch: Replacing suffix + trees with enhanced suffix arrays. Journal of Discrete Algorithms 2 (2004) + 53-86.] and was originally adapted from [Kasai, Lee, Arimura, Arikawa, Park: + Linear-Time Longest-Common-Prefix Computation in Suffix Arrays and Its + Applications, CPM 2001]. + + The motivation here is that we are likely limited more by time than memory, + and we want a data structure that is relatively simple to use. + + @param [in] c Context pointer, used to create arrays. Required to + be a CPU context pointer for now. + @param [in] seq_len The length of the text for which we have a suffix + array + @param [in] lcp_array The LCP array, as computed by CreateLcpArray() + @param [out] lcp_intervals A *newly created* array of LcpInterval + will be written to here, of length no greater than seq_len. + @param [out] lcp_intervals_order If this is non-NULL, a newly + created array will be written to here, giving a bottom-up + order of the lcp-intervals so that each child comes before + its parent. This is a permutation of the numbers + [0,1,...lcp_intervals->Dim()-1]. + @param [out] leaf_parent_intervals If this is non-NULL, a newly + created array of size seq_len will be written to here, + saying, for each leaf in the suffix tree (corresponding to + positions in the suffix array) which lcp-interval + is its parent. Indexes into this array correspond to + indexes into the suffix array, and values correspond + to indexes into `lcp_intervals`. + */ +template +void CreateLcpIntervalArray(ContextPtr c, + T seq_len, + T *lcp_array, + Array1 > *lcp_intervals, + Array1 *lcp_intervals_order, + Array1 *leaf_parent_intervals); + + + + + + + +} +#endif // K2_CSRC_NBEST_H_ From bf20c119a14e031f9f3972e6918b442d83a827f7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 13 Jul 2021 20:38:35 +0800 Subject: [PATCH 02/16] Revert change in comment regarding seq_len+1 --- k2/csrc/nbest.cu | 4 +++- k2/csrc/nbest.h | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 028a898e8..49addebc8 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -179,7 +179,7 @@ void CreateLcpIntervalArray(ContextPtr c, Array1 *lcp_intervals_order, Array1 *leaf_parent_intervals) { - // + *lcp_intervals = Array1 >(c, seq_len); LcpInterval *lcp_intervals_data = lcp_intervals->Data(); @@ -216,6 +216,8 @@ void CreateLcpIntervalArray(ContextPtr c, stack.push_back({0, 0, seq_len, next++ }); lcp_intervals_data[0] = stack.back(); // We are using zero-based indexing so the code is not quite the same as our reference. + // Also, http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf + // seems to be expecting a suffix array of size seq_len + 1, not seq_len. for (T i = 0; i < seq_len; ++i) { T lb = i, lcp_array_i = lcp_array[i]; leaf_stack.push_back(lb); diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index 9f6743e10..5c117d17f 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -68,14 +68,16 @@ namespace k2 { must be longer than this by at least 3, for termination.) Require seq_len >= 0 @param [out] suffix_array A pre-allocated array of length - `seq_len + 1`. At exit it will contain a permutation of - the list [ 0, 1, ... seq_len ], interpreted - as the start indexes of suffixes of `text_array`, + `seq_len`. At exit it will contain a permutation of + the list [ 0, 1, ... seq_len - 1 ], interpreted + as the start indexes of the nonempty suffixes of `text_array`, with the property that the sub-arrays of `text_array` starting at these positions are lexicographically sorted. For example, as a trivial case, if seq_len = 3 and text_array contains [ 3, 2, 1, 0, 0, 0 ], then `suffix_array` would contain [ 2, 1, 0 ] at exit. + CAUTION: there is some literature on suffix arrays + that expects the suffix_array size tgo be n + 1, not n. @param [in] max_symbol A number that must be >= the largest number that might be in `text_array`. The work done is O(seq_len + max_symbol), so it is not advisable From b6b5f31e780fc5d526b013cff81ad5e37de6d105 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 13 Jul 2021 21:28:11 +0800 Subject: [PATCH 03/16] Closer to getting tests to work --- k2/csrc/nbest.cu | 4 +++- k2/csrc/nbest.h | 26 ++++++++++++++------------ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 49addebc8..49c730c4a 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -61,6 +61,8 @@ static void RadixPass(const T* a, T* b, const T* r, T n, T K) { // https://algo2.iti.kit.edu/documents/jacm05-revised.pdf. template void CreateSuffixArray(const T* text, T n, T K, T* SA) { + //assert(text[0] <= text[n-1]); // spot check that termination symbol is larger + // than other symbols; <= in case n==1. if (n == 1) { // The paper's code didn't seem to handle n == 1 correctly. SA[0] = 0; return; @@ -223,7 +225,7 @@ void CreateLcpIntervalArray(ContextPtr c, leaf_stack.push_back(lb); while (lcp_array_i < stack.back().lcp) { - stack.back().last = i - 1; + stack.back().last = i; last_interval = stack.back().parent; // actually, the .parent field // currently represents 'self', // i.e. the index of the diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index 5c117d17f..f15e79597 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -57,13 +57,16 @@ namespace k2 { Template args: T should be a signed integer type, we plan to instantiate this for int32_t and int16_t only. - @param [in] text_array Pointer to the input array of symbols - (all pointers must be CPU pointers only, for now), - whose suffixes are to be sorted. Logically this - has length `seq_len`, and symbols are required - to be in the range [1..max_symbol]. It is required - to be terminated by 3 zeros, i.e. - text_array[seq_len] == text_array[seq_len+1] == text_array[seq_len+2] == 0 + @param [in] text_array Pointer to the input array of symbols, + including the termination symbol ($) which must be larger + than the other symbols. + All pointers must be CPU pointers only, for now. + The suffixes of this array are to be sorted. Logically this + array has length `seq_len`, and symbols are required + to be in the range [1..max_symbol]. + text_array is additionally required to be terminated by 3 zeros, + for purposes of this algorithm, i.e. + text_array[seq_len] == text_array[seq_len+1] == text_array[seq_len+2] == 0 @param [in] seq_len Length of the symbol sequence (`text_array` must be longer than this by at least 3, for termination.) Require seq_len >= 0 @@ -74,12 +77,11 @@ namespace k2 { with the property that the sub-arrays of `text_array` starting at these positions are lexicographically sorted. For example, as a trivial case, if seq_len = 3 - and text_array contains [ 3, 2, 1, 0, 0, 0 ], then - `suffix_array` would contain [ 2, 1, 0 ] at exit. - CAUTION: there is some literature on suffix arrays - that expects the suffix_array size tgo be n + 1, not n. + and text_array contains [ 3, 2, 1, 10, 0, 0, 0 ], then + `suffix_array` would contain [ 2, 1, 0, 3 ] at exit. @param [in] max_symbol A number that must be >= the largest - number that might be in `text_array`. The work done + number that might be in `text_array`, including the + termination symbol. The work done is O(seq_len + max_symbol), so it is not advisable to let max_symbol be too large. Caution: this function allocates memory internally (although From 6a8d197093508697692e1eb720df1d372e3e5e64 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 13 Jul 2021 23:01:35 +0800 Subject: [PATCH 04/16] Tests working --- k2/csrc/nbest.cu | 23 +++++++++++++++-------- k2/csrc/nbest.h | 11 +++++++---- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 49c730c4a..e53ca8b02 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -215,22 +215,23 @@ void CreateLcpIntervalArray(ContextPtr c, // that corresponds to depth-first search. T last_interval = -1; // Will store an index into `lcp_intervals`; this comes // from Algorithm 2 mentioned above - stack.push_back({0, 0, seq_len, next++ }); + stack.push_back({0, 0, T(seq_len - 1), next++ }); lcp_intervals_data[0] = stack.back(); - // We are using zero-based indexing so the code is not quite the same as our reference. - // Also, http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf - // seems to be expecting a suffix array of size seq_len + 1, not seq_len. - for (T i = 0; i < seq_len; ++i) { - T lb = i, lcp_array_i = lcp_array[i]; + lcp_intervals_data[0].parent = -1; + // We are using a numbering in which the terminating symbol $ is included + // in the array length, which is why we do "i < seq_len" and not + // "i <= seq_len" as in + // http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf + for (T i = 1; i < seq_len; ++i) { + T lb = i - 1, lcp_array_i = lcp_array[i]; leaf_stack.push_back(lb); while (lcp_array_i < stack.back().lcp) { - stack.back().last = i; last_interval = stack.back().parent; // actually, the .parent field // currently represents 'self', // i.e. the index of the // lcp-interval stack.back(). - lb = stack.back().begin; + lb = stack.back().lb; while (!leaf_stack.empty() && leaf_stack.back() >= lb) { leaf_parent_data[leaf_stack.back()] = last_interval; leaf_stack.pop_back(); @@ -238,6 +239,11 @@ void CreateLcpIntervalArray(ContextPtr c, // process(last_interval): lcp_intervals_data[last_interval] = stack.back(); + // Previously tried doing: + // stack.back().rb = i - 1; + // a bit further above, but hit some kind of compiler problem, the assignment + // had no effect (back() is supposed to return a reference). + lcp_intervals_data[last_interval].rb = i - 1; intervals_order_data[dfs_next++] = last_interval; stack.pop_back(); if (lcp_array_i <= stack.back().lcp) { @@ -258,6 +264,7 @@ void CreateLcpIntervalArray(ContextPtr c, } assert(stack.size() == 1); intervals_order_data[dfs_next++] = 0; + leaf_stack.push_back(seq_len - 1); while (!leaf_stack.empty()) { leaf_parent_data[leaf_stack.back()] = 0; leaf_stack.pop_back(); diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index f15e79597..d1815f348 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -129,12 +129,16 @@ void CreateLcpArray(const T *text_array, 0 <= i < j < seq_len lcptab[i] < l lcptab[j+1] < l - l is the minimum of (lcptab[i], lcptab[i+1], ..., lcptab[j]) + l is the minimum of (lcptab[i+1], lcptab[i+2], ..., lcptab[j]) lcp-intervals correspond to the internal nodes of the suffix trie, so they always contain at least two children (where children can be leaves, corresponding indexes into the suffix array, or other lcp-intervals). + SPECIAL CASE: if seq_len is 1, which is a rather degenerate case, the above + definitions do not quite work; and we treat [0,0] as an lcp-interval with + lcp-value 0 although it does not meet the above definitions. + Type LcpInterval is used to store information about the lcp interval, which we'll later use in algorithms that traverse the suffix tree. @@ -145,9 +149,8 @@ struct LcpInterval { // Represents the lcp-interval [begin,last] with lcp-value `lcp` T lcp; // The lcp-value of the lcp-interval, which is the length of the // longest prefix shared by all suffixes in this interval. - T begin; // Index of the first element - T last; // Index of the last element; we don't call this 'end' because that - // is generally used to mean one past the end. + T lb; // Index of the first element (left boundary) + T rb; // Index of the last elemen (right boundary) T parent; // The parent of this lcp-interval (-1 if this is the top interval), // in the order in which it appears in this array (of // lcp-intervals). Note: this order is neither top-down or From fe0c0ebeb01d26ebc2a0c1ccc2b5313cbb1bc134 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 13 Jul 2021 23:26:25 +0800 Subject: [PATCH 05/16] Simplify interface of code to create LCP intervals. Internal code still needs cleanup --- k2/csrc/nbest.cu | 41 +++++++++++++++++++++++------------------ k2/csrc/nbest.h | 8 ++------ 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index e53ca8b02..6aaeed2a0 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -178,7 +178,6 @@ void CreateLcpIntervalArray(ContextPtr c, T seq_len, T *lcp_array, Array1 > *lcp_intervals, - Array1 *lcp_intervals_order, Array1 *leaf_parent_intervals) { @@ -216,8 +215,6 @@ void CreateLcpIntervalArray(ContextPtr c, T last_interval = -1; // Will store an index into `lcp_intervals`; this comes // from Algorithm 2 mentioned above stack.push_back({0, 0, T(seq_len - 1), next++ }); - lcp_intervals_data[0] = stack.back(); - lcp_intervals_data[0].parent = -1; // We are using a numbering in which the terminating symbol $ is included // in the array length, which is why we do "i < seq_len" and not // "i <= seq_len" as in @@ -231,50 +228,60 @@ void CreateLcpIntervalArray(ContextPtr c, // currently represents 'self', // i.e. the index of the // lcp-interval stack.back(). + T last_interval_dfsorder = dfs_next++; lb = stack.back().lb; while (!leaf_stack.empty() && leaf_stack.back() >= lb) { - leaf_parent_data[leaf_stack.back()] = last_interval; + leaf_parent_data[leaf_stack.back()] = last_interval_dfsorder; leaf_stack.pop_back(); } - // process(last_interval): - lcp_intervals_data[last_interval] = stack.back(); + lcp_intervals_data[last_interval_dfsorder] = stack.back(); // Previously tried doing: // stack.back().rb = i - 1; // a bit further above, but hit some kind of compiler problem, the assignment // had no effect (back() is supposed to return a reference). - lcp_intervals_data[last_interval].rb = i - 1; - intervals_order_data[dfs_next++] = last_interval; + lcp_intervals_data[last_interval_dfsorder].rb = i - 1; + intervals_order_data[last_interval] = last_interval_dfsorder; stack.pop_back(); if (lcp_array_i <= stack.back().lcp) { - // lcp_intervals_data[last_interval].parent represents the parent + // lcp_intervals_data[last_interval_dfsorder].parent represents the parent // of `last_interval`; `stack.back().parent` currently represents // the intended position of stack.back() itself, not of its parent. - lcp_intervals_data[last_interval].parent = stack.back().parent; + lcp_intervals_data[last_interval_dfsorder].parent = stack.back().parent; last_interval = -1; } } if (lcp_array_i > stack.back().lcp) { if (last_interval >= 0) { - lcp_intervals_data[last_interval].parent = next; + lcp_intervals_data[intervals_order_data[last_interval]].parent = next; last_interval = -1; } stack.push_back({lcp_array_i, lb, -1, next++}); } } assert(stack.size() == 1); - intervals_order_data[dfs_next++] = 0; + T top_node_dfsorder = dfs_next++; + lcp_intervals_data[top_node_dfsorder] = stack.back(); + lcp_intervals_data[top_node_dfsorder].parent = -1; + intervals_order_data[0] = top_node_dfsorder; leaf_stack.push_back(seq_len - 1); while (!leaf_stack.empty()) { - leaf_parent_data[leaf_stack.back()] = 0; + leaf_parent_data[leaf_stack.back()] = top_node_dfsorder; leaf_stack.pop_back(); } assert(dfs_next == next); - + for (T i = 0; i + 1 < next; i++) { + // for each lpc-interval, except the last (top) node which has -1 as its + // parent field.. Change from pushing order (order in which we pushed them + // onto the stack) to dfs post order (order in which they were popped). + lcp_intervals_data[i].parent = intervals_order_data[ + lcp_intervals_data[i].parent]; + } *lcp_intervals = lcp_intervals->Range(0, next); - if (lcp_intervals_order) - *lcp_intervals_order = intervals_order.Range(0, next); + for (T i = 0; i < next; i++) + intervals_order_data[i] = i; // We output in dfs post order now.. will + // remove this output arg. if (leaf_parent_intervals) *leaf_parent_intervals = leaf_parent; @@ -286,14 +293,12 @@ void CreateLcpIntervalArray(ContextPtr c, int32_t seq_len, int32_t *lcp_array, Array1 > *lcp_intervals, - Array1 *lcp_intervals_order, Array1 *leaf_parent_intervals); template void CreateLcpIntervalArray(ContextPtr c, int16_t seq_len, int16_t *lcp_array, Array1 > *lcp_intervals, - Array1 *lcp_intervals_order, Array1 *leaf_parent_intervals); diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index d1815f348..02fa758fd 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -180,11 +180,8 @@ struct LcpInterval { @param [in] lcp_array The LCP array, as computed by CreateLcpArray() @param [out] lcp_intervals A *newly created* array of LcpInterval will be written to here, of length no greater than seq_len. - @param [out] lcp_intervals_order If this is non-NULL, a newly - created array will be written to here, giving a bottom-up - order of the lcp-intervals so that each child comes before - its parent. This is a permutation of the numbers - [0,1,...lcp_intervals->Dim()-1]. + They will be in dfs post order. Children precede their + parents. @param [out] leaf_parent_intervals If this is non-NULL, a newly created array of size seq_len will be written to here, saying, for each leaf in the suffix tree (corresponding to @@ -198,7 +195,6 @@ void CreateLcpIntervalArray(ContextPtr c, T seq_len, T *lcp_array, Array1 > *lcp_intervals, - Array1 *lcp_intervals_order, Array1 *leaf_parent_intervals); From c921084a7f5e640a35c0ddc7d3e9729f55a1c470 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 14 Jul 2021 14:09:11 +0800 Subject: [PATCH 06/16] Tests working.. --- k2/csrc/array_ops.h | 2 +- k2/csrc/array_ops_inl.h | 2 + k2/csrc/nbest.cu | 49 +++++++++++++++ k2/csrc/nbest.h | 129 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 179 insertions(+), 3 deletions(-) diff --git a/k2/csrc/array_ops.h b/k2/csrc/array_ops.h index 1630ec9a2..d8c73e8d7 100644 --- a/k2/csrc/array_ops.h +++ b/k2/csrc/array_ops.h @@ -83,7 +83,7 @@ void ExclusiveSum(const Array1 &src, Array1 *dest) { ExclusiveSum(src.Context(), dest_dim, src.Data(), dest->Data()); } -/* wrapper for the ExclusiveSum above. Will satisfy +/* wrapper for the ExclusiveSum above (returns array with same dim as src). Will satisfy ans[i] = sum_{j=0}^{i-1} src[j] for i > 0. ans[0] is always 0. */ diff --git a/k2/csrc/array_ops_inl.h b/k2/csrc/array_ops_inl.h index a23b6e2b9..d24e9a58c 100644 --- a/k2/csrc/array_ops_inl.h +++ b/k2/csrc/array_ops_inl.h @@ -260,6 +260,7 @@ void ExclusiveSumDeref(Array1 &src, Array1 *dest) { if (dest_dim == src_dim + 1) { const RegionPtr ®ion = src.GetRegion(); ssize_t byte_offset = static_cast(src.ByteOffset()); + // If this fails: you must allocate one extra element past the end of src! K2_CHECK_GE(region->num_bytes - byte_offset, dest_dim * src.ElementSize()); } internal::PtrPtr src_data = internal::PtrPtr(src.Data()); @@ -285,6 +286,7 @@ void ExclusiveSum(const Array2 &src, Array2 *dest, int32_t axis) { if (dest_major_dim == src_major_dim + 1) { const RegionPtr ®ion = src.GetRegion(); ssize_t byte_offset = static_cast(src.ByteOffset()); + // If this fails: you must allocate one extra element past the end of src! K2_CHECK_GE(region->num_bytes - byte_offset, (src_major_dim * src_minor_dim + 1) * src.ElementSize()); } diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 6aaeed2a0..e4f44e055 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -301,5 +301,54 @@ void CreateLcpIntervalArray(ContextPtr c, Array1 > *lcp_intervals, Array1 *leaf_parent_intervals); +template +void FindTightestNonemptyIntervals(T seq_len, + Array1 > *lcp_intervals, + Array1 *counts_exclusive_sum, + Array1 *leaf_parent_intervals) { + ContextPtr c = lcp_intervals->Context(); + K2_CHECK(counts_exclusive_sum->Dim() == seq_len + 1); + K2_CHECK(leaf_parent_intervals->Dim() == seq_len); + + const LcpInterval *lcp_intervals_data = lcp_intervals->Data(); + const T *counts_exclusive_sum_data = counts_exclusive_sum->Data(); + int32_t num_intervals = lcp_intervals->Dim(); + // `tightest_nonempty_intervals` gives, for each interval 0 <= i < num_intervals, + // the index j >= i of the tightest enclosing interval that has a nonzero + // count. As a special case, if all counts are zero, it will return the + // top (last) interval. + Array1 tightest_nonempty_interval(c, num_intervals); + T *tightest_nonempty_interval_data = tightest_nonempty_interval.Data(); + for (T i = num_intervals - 1; i >= 0; --i) { + T j; + LcpInterval cur_interval = lcp_intervals_data[i]; + if (cur_interval.parent < 0 || // top node + counts_exclusive_sum_data[cur_interval.rb + 1] > + counts_exclusive_sum_data[cur_interval.lb]) { + j = i; + } else { + // j > i, we will have already set tightest_nonempty_interval_data at this + // location. + j = tightest_nonempty_interval_data[cur_interval.parent]; + } + tightest_nonempty_interval_data[i] = j; + } + T *leaf_parent_intervals_data = leaf_parent_intervals->Data(); + for (T i = 0; i < seq_len; ++i) + leaf_parent_intervals_data[i] = tightest_nonempty_interval_data[ + leaf_parent_intervals_data[i]]; +} + +// Instantiate template +template +void FindTightestNonemptyIntervals(int32_t seq_len, + Array1 > *lcp_intervals, + Array1 *counts_exclusive_sum, + Array1 *leaf_parent_intervals); +template +void FindTightestNonemptyIntervals(int16_t seq_len, + Array1 > *lcp_intervals, + Array1 *counts_exclusive_sum, + Array1 *leaf_parent_intervals); } // namespace k2 diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index 02fa758fd..0c4509044 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -97,8 +97,8 @@ void CreateSuffixArray(const T *text_array, Create the LCP array, which is the array of lengths of longest common prefixes (of successive pairs of suffixes). - Template args: T should be a signed integer type, we - plan to instantiate this for int32_t and int16_t only. + Template args: T should be a signed integer type, we plan to instantiate this + for int32_t and int16_t only. @param [in] text_array The array of symbols, of length `seq_len` plus at least one terminating zero. The symbols should be positive @@ -119,6 +119,8 @@ void CreateLcpArray(const T *text_array, T *lcp_array); /* + Template args: T is a signed type, intended to be int16_t or int32_t + lcp-intervals correspond to the nodes in the suffix trie; they are a concept used with suffix arrays, and are derived from the LCP table (see lcp_array output of CreateLcpArray). Take care with the notation here: intervals are @@ -173,6 +175,8 @@ struct LcpInterval { The motivation here is that we are likely limited more by time than memory, and we want a data structure that is relatively simple to use. + Template args: T is a signed type, intended to be int16_t or int32_t + @param [in] c Context pointer, used to create arrays. Required to be a CPU context pointer for now. @param [in] seq_len The length of the text for which we have a suffix @@ -197,9 +201,130 @@ void CreateLcpIntervalArray(ContextPtr c, Array1 > *lcp_intervals, Array1 *leaf_parent_intervals); +/* + Modifies `leaf_parent_intervals` to give us, for each position in the suffix + array (i.e. each suffix), the tightest enclosing lcp-interval that has + nonzero count. This is used in finding the highest-order match of + a position in a text (i.e. the longest matching history). + + Template args: T is a signed type, intended to be int16_t or int32_t + + @param [in] seq_len The length of the sequence, including the terminating $. + @param [in] lcp_intervals The array of lcp intervals, as returned + by CreateLcpIntervalArray + @param [in] counts_exclusive_sum The exclusive-sum of counts of symbols in + the original text array, in the order given by the suffix + array, e.g. the original counts would have satisfied + suffix_counts[i] = counts[suffix_array[i]], and then + counts_exclusive_sum is the exclusive-sum of suffix_counts. + Must satisfy counts_exclusive_sum->Dim() == seq_len + 1. + The original counts would have been 1 for "keys" and 0 for + "queries", so an interval with nonzero difference in + counts_exclusive_sum is an interval containing at least + one key. + @param [in,out] leaf_parent_intervals At entry, this will contain, + for each leaf of the suffix tree (i.e. each position + in the suffix array) the index of the tightest enclosing + lcp-interval, i.e. an index into `lcp_intervals`. + At exit, it will contain the index of the tightest + enclosing *nonempty* lcp-interval. + */ +template +void FindTightestNonemptyIntervals(T seq_len, + Array1 > *lcp_intervals, + Array1 *counts_exclusive_sum, + Array1 *leaf_parent_intervals); + + +/* + + + For "query" sentences, this function gets the mean and variance of scores + from the best matching words-in-context in a set of of provided "key" + sentences. This matching process matches the word and the words preceding + it, looking for the highest-order match it can find (it's intended for + approximating the scores of models that see only left-context, like language + models). It is an efficient implementation using suffix arrays (done on CPU + for now, since the implementation is not very trivial). The intended + application is in estimating the scores of hypothesized transcripts, when we + have actually computed the scores for only a subset of the hypotheses. + @param [in] tokens A ragged tensor of int32_t with 2 or 3 axes (this + function recurses). If 2 axes, this represents a collection of + key and query sequences (keys have count==1, query count==0). + If 3 axes, this represents a set of such collections + and retrieval should be done independently. + 2-axis example: + [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ] + 3-axis example: + [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ], + [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ] + + .. where the words would actually be represented as integers, + and the eos might be -1. The eos symbol is required if this + code is to work as intended (otherwise this code will not + be able to recognize when we have reached the the beginnings + of sentences when comparing histories). bos symbols are + allowed but not required. + + @param [in] scores An array with scores.Dim() == tokens.NumElements(); + this is the item for which we are requesting best-matching + values (as means and variances in case there are multiple + best matches). In our anticipated use, these would represent + scores of words in the sentences, but they could represent + anything. + @param [in] counts An array with counts.Dim() == tokens.NumElements(), + containing 1 for words that are considered "keys" and 0 for + words that are considered "queries". Typically some entire + sentences will be keys and others will be queries. + @param [in] eos The value of the eos (end of sentence) symbol; internally, this + is used as an extra padding value before the first sentence in each + collection, so that it can act like a "bos" symbol. + @param [in] min_token The lowest possible token value, including the bos + symbol (e.g., might be -1). + @param [in] max_token The maximum possible token value. Be careful not to + set this too large the implementation contains a part which + takes time and space O(max_token - min_token). + @param [in] max_order The maximum n-gram order to ever return in the + `ngram_order` output; the output will be the minimum of max_order + and the actual order matched; or max_order if we matched all the + way to the beginning of both sentences. The main reason this is + needed is that we need a finite number to return at the + beginning of sentences. + + @param [out] mean For query positions, will contain the mean of the + scores at the best matching key positions, or zero if that is + undefined because there are no key positions at all. For key positions, + you can treat the output as being undefined (actually they + are treated the same as queries, but won't match with only + themselves because we don't match at singleton intervals). This array + will be allocated if it did not have the correct size at + entry. + @param [out] var Like `mean`, but contains the (centered) variance + of the best matching positions. + @param [out] count The number of key positions that contributed + to the `mean` and `var` statistics. This should only + be zero if `counts` was all zero. Will be allocated + if it did not have the correct size at entry. + @param [out] ngram_order The n-gram order corresponding to the + best matching positions found at each query position, up + to a maximum of `max_order`; will be `max_order` if we matched all + the way to the beginning of a sentence. Will be allocated if it + did not have the correct size at entry. +*/ +void GetBestMatchingStats(Ragged &tokens, + Array1 &scores, + Array1 &counts, + int32_t eos, + int32_t min_token, + int32_t max_token, + int32_t max_order, + Array1 *mean, + Array1 *var, + Array1 *count, + Array1 *ngram_order); From 4d0a8e43c8c1272316df945d2c7f2205fdf2ab79 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 14 Jul 2021 18:51:54 +0800 Subject: [PATCH 07/16] Add skeleton for nbest outer code (C++) --- k2/csrc/nbest.cu | 97 ++++++++++++++++++++++++++++++++++++++++++++++++ k2/csrc/nbest.h | 2 +- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index e4f44e055..5fbc38742 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -351,4 +351,101 @@ void FindTightestNonemptyIntervals(int16_t seq_len, Array1 *counts_exclusive_sum, Array1 *leaf_parent_intervals); + + +// Internal implementation of GetBestMatchingStats(), that handles the case +// where tokens.NumAxes() == 2 and tokens.NumElements() > 0. It will +// be instantiated with int16_t if the size of the problem permits, and +// int32_t otherwise (this size is used for +template +void GetBestMatchingStatsInternal(Ragged &tokens, + Array1 &scores, + Array1 &counts, + int32_t eos, + int32_t min_token, + int32_t max_token, + int32_t max_order, + Array1 *mean, + Array1 *var, + Array1 *counts_out, + Array1 *ngram_order) { +} + + +void GetBestMatchingStats(Ragged &tokens, + Array1 &scores, + Array1 &counts, + int32_t eos, + int32_t min_token, + int32_t max_token, + int32_t max_order, + Array1 *mean, + Array1 *var, + Array1 *counts_out, + Array1 *ngram_order) { + ContextPtr c = tokens.Context(); + int32_t num_elements = tokens.NumElements(); + if (mean->Dim() != num_elements) { + *mean = Array1(c, num_elements); + } else { + K2_CHECK_EQ(mean->Dim(), 0); + } + if (var->Dim() != num_elements) { + *var = Array1(c, num_elements); + } else { + K2_CHECK_EQ(var->Dim(), 0); + } + if (counts_out->Dim() != num_elements) { + *counts_out = Array1(c, num_elements); + } else { + K2_CHECK_EQ(counts_out->Dim(), 0); + } + if (ngram_order->Dim() != num_elements) { + *ngram_order = Array1(c, num_elements); + } else { + K2_CHECK_EQ(ngram_order->Dim(), 0); + } + + K2_CHECK(eos >= min_token && eos <= max_token); + K2_CHECK_GE(max_order, 2); + K2_CHECK_EQ(num_elements, scores.Dim()); + K2_CHECK_EQ(num_elements, counts.Dim()); + + if (tokens.NumAxes() == 3) { + int32_t num_collections = tokens.Dim0(); + for (int32_t i = 0; i < num_collections; i++) { + Ragged this_tokens = tokens.Index(0, i); + int32_t begin = this_tokens.values.Data() - tokens.values.Data(), + end = begin + this_tokens.NumElements(); + Array1 this_scores = scores.Arange(begin, end), + this_mean = mean->Arange(begin, end), + this_var = var->Arange(begin, end); + Array1 this_counts = counts.Arange(begin, end), + this_counts_out = counts_out->Arange(begin, end), + this_ngram_order = ngram_order->Arange(begin, end); + + GetBestMatchingStats(this_tokens, this_scores, this_counts, eos, min_token, + max_token, max_order, + &this_mean, &this_var, + &this_counts_out, &this_ngram_order); + } + return; + } + K2_CHECK_EQ(tokens.NumAxes(), 2); // Only 2 or 3 axes are allowed. + + if (num_elements == 0) { + return; // Nothing to do. + } else if (num_elements + 10 < (1 << 15) && (max_token - min_token + 10 < (1 << 15))) { + GetBestMatchingStatsInternal(tokens, scores, counts, eos, min_token, max_token, + max_order, mean, var, counts_out, ngram_order); + } else { + GetBestMatchingStatsInternal(tokens, scores, counts, eos, min_token, max_token, + max_order, mean, var, counts_out, ngram_order); + } +} + + + + + } // namespace k2 diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index 0c4509044..3cb4fde9f 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -304,7 +304,7 @@ void FindTightestNonemptyIntervals(T seq_len, entry. @param [out] var Like `mean`, but contains the (centered) variance of the best matching positions. - @param [out] count The number of key positions that contributed + @param [out] counts_out The number of key positions that contributed to the `mean` and `var` statistics. This should only be zero if `counts` was all zero. Will be allocated if it did not have the correct size at entry. From 68ca9dc21e765f39f2954c25096ab824fb4a36c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 14 Jul 2021 19:04:02 +0800 Subject: [PATCH 08/16] Add draft and comments for GetBestMatchingStatsInternal() --- k2/csrc/nbest.cu | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 5fbc38742..af526267c 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -369,6 +369,40 @@ void GetBestMatchingStatsInternal(Ragged &tokens, Array1 *var, Array1 *counts_out, Array1 *ngram_order) { + + // Outline: + // First construct an array of type T which contains values as follows: + // [ tokens.values[-1]+offset, ..., tokens.values[1]+offset, tokens.values[0]+offset, eos+offset, terminator, 0, 0 ] + // where offset is 1-min_token, and terminator is max_token+1+offset. + // The 3 terminating zeros are required by CreateSuffixArray(). + // + // Create the reordered counts array `counts_reordered`, in the same order as + // the suffix array, then its exclusive sum, e.g. `counts_reordered_excsum`. + // At this point we can also create similar reordered exclusive-sums of `scores` + // and scores-squared; do these as double or roundoff will be a problem. + // + // Call CreateSuffixArray (seq_len == tokens.Dim() + 2, we include the eos and terminator). + // Call CreateLcpArray, CreateLcpIntervalArray, FindTightestNonemptyIntervals() + + // By this point we should have enough information to directly create the + // outputs : mean, var, counts_out, ngram_order. We need to be a bit careful + // about ngram_order at positions when the suffix goes up to the next eos + // (i.e. it goes to the beginning of the sentence) because the correct ngram order + // to output here is `max_order`. You will have to create an array containing + // the distance from the beginning of the sentence (can be constructed from + // the row_ids and row_splits of `tokens`), and index it with the suffix array + // to get it in the right order. + // Note: we only really care about the output at the query positions, but try + // to make it so you don't need to treat keys as a special case. + // + // Special cases/conditions to consider include: + // - make sure the `count` in the position of the eos and terminator are zero + // - various code may break if the total count over all these sentences is + // zero, so you could just detect that and treat it as a special case. + // If the total count is nonzero, it should be guaranteed that you never + // have to process an interval with zero count; FindTightestNonemptyIntervals() + // should guarantee that. + } @@ -423,7 +457,6 @@ void GetBestMatchingStats(Ragged &tokens, Array1 this_counts = counts.Arange(begin, end), this_counts_out = counts_out->Arange(begin, end), this_ngram_order = ngram_order->Arange(begin, end); - GetBestMatchingStats(this_tokens, this_scores, this_counts, eos, min_token, max_token, max_order, &this_mean, &this_var, From 5295265b03b888f636b5bcd698f9a14cd45d7168 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Jul 2021 15:25:55 +0800 Subject: [PATCH 09/16] Add missing test program --- k2/csrc/nbest_test.cu | 305 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 k2/csrc/nbest_test.cu diff --git a/k2/csrc/nbest_test.cu b/k2/csrc/nbest_test.cu new file mode 100644 index 000000000..8628ac9bd --- /dev/null +++ b/k2/csrc/nbest_test.cu @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Xiaomi Corporation (authors: Haowen Qiu) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "k2/csrc/nbest.h" + +namespace k2 { +TEST(AlgorithmsTest, TestSuffixArray) { + ContextPtr cpu = GetCpuContext(); + + for (int i = 0; i < 100; i++) { + int array_len = RandInt(1, 50), // 1 is min, due to termination symbol. + max_symbol = RandInt(10, 500); + + Array1 array(cpu, array_len + 3); + int32_t *array_data = array.Data(); + for (int i = 0; i + 1 < array_len; i++) + array_data[i] = RandInt(1, max_symbol - 1); // termination symbol must be + // larger than all others, + // don't allow + array_data[array_len - 1] = max_symbol; // Termination symbol + + for (int i = array_len; i < array_len + 3; i++) + array_data[i] = 0; + + // really array_len + 1, extra elem is to test that it doesn't write past + // the end. + Array1 suffix_array(cpu, array_len + 2); + int32_t *suffix_array_data = suffix_array.Data(); + suffix_array_data[array_len + 1] = -10; // should not be changed. + CreateSuffixArray(array_data, array_len, + max_symbol, suffix_array_data); + assert(suffix_array_data[array_len + 1] == -10); // should be unchanged. + Array1 seen_indexes(cpu, array_len, 0); + int32_t *seen_indexes_data = seen_indexes.Data(); + for (int32_t i = 0; i < array_len; i++) + seen_indexes_data[suffix_array_data[i]] = 1; + + for (int32_t i = 0; i < array_len; i++) + assert(seen_indexes_data[i] == 1); // make sure all integers seen. + for (int32_t i = 0; i + 1 < array_len; i++) { + int32_t *suffix_a = array_data + suffix_array_data[i], + *suffix_b = array_data + suffix_array_data[i + 1]; + // checking that each suffix is lexicographically less than the next one. + // None are identical, because the terminating zero is always in different + // positions. + while (true) { + if (*suffix_a < *suffix_b) + break; // correct order + assert(!(*suffix_a > *suffix_b)); // order is wrong! + assert(!(suffix_b > array_data + array_len || + suffix_b > array_data + array_len)); // past array end without correct comparison order. + suffix_a++; + suffix_b++; + } + } + } +} + + +TEST(AlgorithmsTest, TestCreateLcpArray) { + ContextPtr cpu = GetCpuContext(); + + for (int i = 0; i < 100; i++) { + int array_len = RandInt(1, 50), // at least 1 due to termination symbol + max_symbol = RandInt(2, 5); + + Array1 array(cpu, array_len + 3); + int32_t *array_data = array.Data(); + for (int i = 0; i + 1 < array_len; i++) + array_data[i] = RandInt(1, max_symbol - 1); + array_data[array_len - 1] = max_symbol; // Termination symbol + for (int i = array_len; i < array_len + 3; i++) + array_data[i] = 0; + + Array1 suffix_array(cpu, array_len); + int32_t *suffix_array_data = suffix_array.Data(); + CreateSuffixArray(array_data, array_len, + max_symbol, suffix_array_data); + + Array1 lcp(cpu, array_len); + int32_t *lcp_data = lcp.Data(); + CreateLcpArray(array_data, suffix_array_data, array_len, + lcp_data); + if (array_len > 0) + assert(lcp_data[0] == 0); + for (int32_t i = 1; i < array_len; i++) { + int32_t lcp = lcp_data[i], + prev_pos = suffix_array_data[i - 1], + this_pos = suffix_array_data[i]; + for (int32_t j = 0; j < lcp; j++) + assert(array_data[prev_pos + j] == array_data[this_pos + j]); + assert(array_data[prev_pos + lcp] != array_data[this_pos + lcp]); + } + } +} + + +TEST(AlgorithmsTest, TestCreateLcpIntervalArray) { + ContextPtr cpu = GetCpuContext(); + + for (int i = 0; i < 100; i++) { + int array_len = RandInt(1, 50), // at least 1 due to termination symbol + max_symbol = RandInt(3, 5); + + Array1 array(cpu, array_len + 3); + int32_t *array_data = array.Data(); + for (int i = 0; i + 1 < array_len; i++) + array_data[i] = RandInt(1, max_symbol - 1); + array_data[array_len - 1] = max_symbol; // Termination symbol + for (int i = array_len; i < array_len + 3; i++) + array_data[i] = 0; + + Array1 suffix_array(cpu, array_len); + int32_t *suffix_array_data = suffix_array.Data(); + CreateSuffixArray(array_data, array_len, + max_symbol, suffix_array_data); + + Array1 lcp(cpu, array_len); + int32_t *lcp_data = lcp.Data(); + CreateLcpArray(array_data, suffix_array_data, array_len, + lcp_data); + + Array1 > lcp_intervals; + Array1 leaf_parent_intervals; + + CreateLcpIntervalArray(GetCpuContext(), + array_len, lcp_data, + &lcp_intervals, + &leaf_parent_intervals); + + LcpInterval *lcp_intervals_data = lcp_intervals.Data(); + int32_t *leaf_parent_intervals_data = leaf_parent_intervals.Data(); + int32_t num_intervals = lcp_intervals.Dim(); + for (int32_t i = 0; i < array_len; i++) { + int32_t lcp_interval = leaf_parent_intervals_data[i]; + assert(lcp_interval >= 0 && lcp_interval < num_intervals); + assert(i >= lcp_intervals_data[lcp_interval].lb && + i <= lcp_intervals_data[lcp_interval].rb); + int32_t lcp = lcp_intervals_data[lcp_interval].lcp; // the lcp value / height + + for (int32_t j = 0; j < num_intervals; j++) { + // The interval that i is a member of should be the tightest enclosing + // interval, this loop checks that. + if (lcp_intervals_data[j].lcp >= lcp && j != lcp_interval) { + assert(!(i >= lcp_intervals_data[j].lb && + i <= lcp_intervals_data[j].rb)); + } + } + } + + for (int32_t i = 0; i < num_intervals; i++) { + LcpInterval interval = lcp_intervals_data[i]; + if (!(interval.lb == 0 && interval.rb + 1 == array_len && + interval.parent == -1)) { + assert(interval.parent > i); + LcpInterval parent = lcp_intervals_data[interval.parent]; + assert(interval.lb >= parent.lb && + interval.rb <= parent.rb && + interval.lcp > parent.lcp); + } + // Now check the basic requirements/definition of lcp interval... + assert(interval.lb >= 0 && + (interval.rb > interval.lb || array_len == 1) && + interval.rb < array_len); + assert(lcp_data[interval.lb] < interval.lcp || + (interval.lb == 0 && interval.lcp == 0)); + assert(interval.rb == array_len - 1 || lcp_data[interval.rb + 1] < interval.lcp); + if (array_len != 1) { + int32_t min_lcp = 1000000; + for (int32_t j = interval.lb + 1; j <= interval.rb; ++j) + if (lcp_data[j] < min_lcp) + min_lcp = lcp_data[j]; + assert(min_lcp == interval.lcp); // Check lcp value is correct. This + // test does not work if array_len == + // 1 so we skip it in that case. + } + } + } +} + + + +TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { + ContextPtr cpu = GetCpuContext(); + + for (int i = 0; i < 100; i++) { + int array_len = RandInt(1, 50), // at least 1 due to termination symbol + max_symbol = RandInt(3, 5); + + Array1 array(cpu, array_len + 3), + counts(cpu, array_len); + int32_t *array_data = array.Data(); + for (int i = 0; i + 1 < array_len; i++) + array_data[i] = RandInt(1, max_symbol - 1); + array_data[array_len - 1] = max_symbol; // Termination symbol + for (int i = array_len; i < array_len + 3; i++) + array_data[i] = 0; + + int32_t *counts_data = counts.Data(); + for (int i = 0; i < array_len; i++) + counts_data[i] = RandInt(0, 1); + + Array1 suffix_array_plusone(cpu, array_len + 1, 0), + suffix_array = suffix_array_plusone.Range(0, array_len); + int32_t *suffix_array_data = suffix_array.Data(); + CreateSuffixArray(array_data, array_len, + max_symbol, suffix_array_data); + + Array1 lcp(cpu, array_len); + int32_t *lcp_data = lcp.Data(); + CreateLcpArray(array_data, suffix_array_data, array_len, + lcp_data); + + Array1 > lcp_intervals; + Array1 leaf_parent_intervals; // dim will be seq_len + + CreateLcpIntervalArray(GetCpuContext(), + array_len, lcp_data, + &lcp_intervals, + &leaf_parent_intervals); + + + // we get one extra don't-care element at the end of `counts_reordered`, + // which is required by ExclusiveSum(). + Array1 counts_reordered = counts[suffix_array_plusone], + counts_reordered_sum(cpu, array_len + 1); + ExclusiveSum(counts_reordered, &counts_reordered_sum); + + + Array1 leaf_parent_intervals_mod(leaf_parent_intervals.Clone()); + + FindTightestNonemptyIntervals(array_len, + &lcp_intervals, + &counts_reordered_sum, + &leaf_parent_intervals_mod); + + LcpInterval *lcp_intervals_data = lcp_intervals.Data(); + int32_t *leaf_parent_intervals_data = leaf_parent_intervals.Data(), + *leaf_parent_intervals_mod_data = leaf_parent_intervals_mod.Data(); + + int32_t num_intervals = lcp_intervals.Dim(); + for (int32_t i = 0; i < array_len; i++) { + int32_t lcp_interval = leaf_parent_intervals_data[i], + nonempty_lcp_interval = leaf_parent_intervals_mod_data[i]; + assert(lcp_interval >= 0 && lcp_interval < num_intervals); + assert(nonempty_lcp_interval >= 0 && nonempty_lcp_interval < num_intervals); + if (counts_reordered_sum[array_len] == 0) { + // If the total count is zero, everything should go to the top of the + // tree, but we won't otherwise test this. + assert(nonempty_lcp_interval == num_intervals - 1); + } else { + int32_t lcp = lcp_intervals_data[lcp_interval].lcp; + K2_CHECK_EQ((lcp_interval == nonempty_lcp_interval), + (counts_reordered_sum[lcp_intervals_data[lcp_interval].lb] != + counts_reordered_sum[lcp_intervals_data[lcp_interval].rb + 1])); + K2_CHECK(i >= lcp_intervals_data[nonempty_lcp_interval].lb && + i <= lcp_intervals_data[nonempty_lcp_interval].rb); + + for (int32_t j = 0; j < num_intervals; j++) { + // nonempty_lcp_interval should be the tightest enclosing + // interval that has nonzero count, this loop checks that. + if (lcp_intervals_data[j].lcp >= lcp && j != nonempty_lcp_interval) { + // Check that this is not a tighter enclosing interval than + // nonempty_lcp_interval, with nonzero count, that encloses i. + K2_CHECK(!(i >= lcp_intervals_data[j].lb && + i <= lcp_intervals_data[j].rb && + counts_reordered_sum[lcp_intervals_data[j].lb] != + counts_reordered_sum[lcp_intervals_data[j].rb + 1])); + } + } + } + } + } +} + + + + + +} // namespace k2 From 5e601bd5a7d126ee70e7fd473f0b5a56ef8e7b81 Mon Sep 17 00:00:00 2001 From: pkufool Date: Sat, 17 Jul 2021 19:14:24 +0800 Subject: [PATCH 10/16] implements GetBestMatchingStatsInternal --- k2/csrc/nbest.cu | 265 ++++++++++++++++++++++++++++++++---------- k2/csrc/nbest.h | 21 ++-- k2/csrc/nbest_test.cu | 55 ++++++--- 3 files changed, 251 insertions(+), 90 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index af526267c..e7aa258ad 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -1,5 +1,6 @@ /** - * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey) + * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey + * Wei Kang) * * See LICENSE for clarification regarding multiple authors * @@ -31,10 +32,9 @@ inline bool Leq(T a1, T a2, T b1, T b2) { template inline bool Leq(T a1, T a2, T a3, T b1, T b2, T b3) { // lexicographic order for triples, used in CreateSuffixArray() - return(a1 < b1 || a1 == b1 && Leq(a2,a3, b2,b3)); + return(a1 < b1 || a1 == b1 && Leq(a2, a3, b2, b3)); } - /* Helper function for CreateSuffixArray(). @@ -45,79 +45,84 @@ inline bool Leq(T a1, T a2, T a3, T b1, T b2, T b3) { */ template static void RadixPass(const T* a, T* b, const T* r, T n, T K) { - T* c = new T[K + 1]; // counter array - for (T i = 0; i <= K; i++) c[i] = 0; // reset counters - for (T i = 0; i < n; i++) c[r[a[i]]]++; // count occurrences - for (T i = 0, sum = 0; i <= K; i++) {// exclusive prefix sums + T* c = new T[K + 1]; // counter array + for (T i = 0; i <= K; i++) c[i] = 0; // reset counters + for (T i = 0; i < n; i++) c[r[a[i]]]++; // count occurrences + for (T i = 0, sum = 0; i <= K; i++) { // exclusive prefix sums T t = c[i]; c[i] = sum; sum += t; } - for (T i = 0; i < n; i++) b[c[r[a[i]]]++] = a[i]; // sort + for (T i = 0; i < n; i++) b[c[r[a[i]]]++] = a[i]; // sort delete [] c; } - // See documentation in nbest.h, where we use different names // for the arguments (here, we leave the names the same as in // https://algo2.iti.kit.edu/documents/jacm05-revised.pdf. template void CreateSuffixArray(const T* text, T n, T K, T* SA) { - //assert(text[0] <= text[n-1]); // spot check that termination symbol is larger - // than other symbols; <= in case n==1. + //assert(text[0] <= text[n-1]); // spot check that termination symbol is + // larger than other symbols; + // <= in case n==1. if (n == 1) { // The paper's code didn't seem to handle n == 1 correctly. SA[0] = 0; return; } - T n0=(n+2)/3, n1=(n+1)/3, n2=n/3, n02=n0+n2; - T* R = new T[n02 + 3]; R[n02]= R[n02+1]= R[n02+2]=0; - T* SA12 = new T[n02 + 3]; SA12[n02]=SA12[n02+1]=SA12[n02+2]=0; - T* R0 = new T[n0]; - T* SA0 = new T[n0]; + T n0 = (n + 2) / 3, n1 = (n+1) / 3, n2 = n / 3, n02 = n0 + n2; + T *R = new T[n02 + 3]; R[n02] = R[n02 + 1] = R[n02 + 2] = 0; + T *SA12 = new T[n02 + 3]; SA12[n02] = SA12[n02 + 1] = SA12[n02 + 2] = 0; + T *R0 = new T[n0]; + T *SA0 = new T[n0]; //******* Step 0: Construct sample ******** // generate positions of mod 1 and mod 2 suffixes // the "+(n0-n1)" adds a dummy mod 1 suffix if n%3 == 1 - for (T i=0, j=0; i < n+(n0-n1); i++) if (i%3 != 0) R[j++] = i; + for (T i = 0, j = 0; i < n + (n0 - n1); i++) if (i % 3 != 0) R[j++] = i; //******* Step 1: Sort sample suffixes ******** // lsb radix sort the mod 1 and mod 2 triples - RadixPass(R, SA12, text+2, n02, K); - RadixPass(SA12, R , text+1, n02, K); + RadixPass(R, SA12, text + 2, n02, K); + RadixPass(SA12, R , text + 1, n02, K); RadixPass(R, SA12, text, n02, K); // find lexicographic names of triples and // write them to correct places in R T name = 0, c0 = -1, c1 = -1, c2 = -1; for (T i = 0; i < n02; i++) { - if (text[SA12[i]] != c0 || text[SA12[i]+1] != c1 || text[SA12[i]+2] != c2) - { name++; c0 = text[SA12[i]]; c1 = text[SA12[i]+1]; c2 = text[SA12[i]+2]; } - if (SA12[i] % 3 == 1) { R[SA12[i]/3] = name; } // write to R1 - else { R[SA12[i]/3 + n0] = name; } // write to R2 + if (text[SA12[i]] != c0 || text[SA12[i] + 1] != c1 || + text[SA12[i] + 2] != c2) { + name++; + c0 = text[SA12[i]]; + c1 = text[SA12[i] + 1]; + c2 = text[SA12[i] + 2]; + } + if (SA12[i] % 3 == 1) { R[SA12[i] / 3] = name; } // write to R1 + else { R[SA12[i] / 3 + n0] = name; } // write to R2 } // recurse if names are not yet unique if (name < n02) { CreateSuffixArray(R, n02, name, SA12); // store unique names in R using the suffix array for (T i = 0; i < n02; i++) R[SA12[i]] = i + 1; - } else // generate the suffix array of R directly + } else // generate the suffix array of R directly for (T i = 0; i < n02; i++) SA12[R[i] - 1] = i; //******* Step 2: Sort nonsample suffixes ******** // stably sort the mod 0 suffixes from SA12 by their first character - for (T i=0, j=0; i < n02; i++) if (SA12[i] < n0) R0[j++] = 3*SA12[i]; + for (T i = 0, j = 0; i < n02; i++) if (SA12[i] < n0) R0[j++] = 3 * SA12[i]; RadixPass(R0, SA0, text, n0, K); //******* Step 3: Merge ******** // merge sorted SA0 suffixes and sorted SA12 suffixes - for (T p=0, t=n0-n1, k=0; k < n; k++) { + for (T p = 0, t = n0 - n1, k = 0; k < n; k++) { // i is pos of current offset 12 suffix T i = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2) ; - T j = SA0[p]; // pos of current offset 0 suffix - if (SA12[t] < n0 ? // different compares for mod 1 and mod 2 suffixes - Leq(text[i], R[SA12[t] + n0], text[j], R[j/3]) : - Leq(text[i],text[i+1],R[SA12[t]-n0+1], text[j],text[j+1],R[j/3+n0])) - { // suffix from SA12 is smaller + T j = SA0[p]; // pos of current offset 0 suffix + if (SA12[t] < n0 ? // different compares for mod 1 and mod 2 suffixes + Leq(text[i], R[SA12[t] + n0], text[j], R[j / 3]) : + Leq(text[i], text[i + 1], R[SA12[t] - n0 + 1], text[j], + text[j + 1], R[j / 3 + n0])) { // suffix from SA12 is smaller SA[k] = i; t++; - if (t == n02) // done --- only SA0 suffixes left + if (t == n02) // done --- only SA0 suffixes left for (k++; p < n0; p++, k++) SA[k] = SA0[p]; - } else { // suffix from SA0 is smaller + } else { // suffix from SA0 is smaller SA[k] = j; p++; - if (p == n0) // done --- only SA12 suffixes left + if (p == n0) // done --- only SA12 suffixes left for (k++; t < n02; t++, k++) SA[k] = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2); } @@ -131,8 +136,6 @@ template void CreateSuffixArray(const int32_t* text, int32_t n, template void CreateSuffixArray(const int16_t* text, int16_t n, int16_t K, int16_t* SA); - - // This implements Kasai's algorithm, as summarized here // https://people.csail.mit.edu/jshun/lcp.pdf // (Note: there seem to be some wrong implementations of @@ -147,9 +150,7 @@ void CreateLcpArray(const T *array, for (T i = 0; i < seq_len; i++) { inv_suffix_array_data[suffix_array[i]] = i; } - T k = 0; - if (seq_len > 0) lcp_array[0] = 0; @@ -172,7 +173,6 @@ template void CreateLcpArray(const int32_t *array, const int32_t *suffix_array, template void CreateLcpArray(const int16_t *array, const int16_t *suffix_array, int16_t seq_len, int16_t *lcp_array); - template void CreateLcpIntervalArray(ContextPtr c, T seq_len, @@ -180,7 +180,6 @@ void CreateLcpIntervalArray(ContextPtr c, Array1 > *lcp_intervals, Array1 *leaf_parent_intervals) { - *lcp_intervals = Array1 >(c, seq_len); LcpInterval *lcp_intervals_data = lcp_intervals->Data(); @@ -190,7 +189,6 @@ void CreateLcpIntervalArray(ContextPtr c, Array1 leaf_parent(c, seq_len); T *leaf_parent_data = leaf_parent.Data(); - // This is the stack from Algorithm 1 and Algorithm 2 of // http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf // (you can refer to the papers mentioned in the documentation in nbest.h if this link goes @@ -274,8 +272,8 @@ void CreateLcpIntervalArray(ContextPtr c, // for each lpc-interval, except the last (top) node which has -1 as its // parent field.. Change from pushing order (order in which we pushed them // onto the stack) to dfs post order (order in which they were popped). - lcp_intervals_data[i].parent = intervals_order_data[ - lcp_intervals_data[i].parent]; + lcp_intervals_data[i].parent = + intervals_order_data[lcp_intervals_data[i].parent]; } *lcp_intervals = lcp_intervals->Range(0, next); @@ -284,7 +282,6 @@ void CreateLcpIntervalArray(ContextPtr c, // remove this output arg. if (leaf_parent_intervals) *leaf_parent_intervals = leaf_parent; - } // Instantiate template @@ -351,8 +348,6 @@ void FindTightestNonemptyIntervals(int16_t seq_len, Array1 *counts_exclusive_sum, Array1 *leaf_parent_intervals); - - // Internal implementation of GetBestMatchingStats(), that handles the case // where tokens.NumAxes() == 2 and tokens.NumElements() > 0. It will // be instantiated with int16_t if the size of the problem permits, and @@ -361,18 +356,17 @@ template void GetBestMatchingStatsInternal(Ragged &tokens, Array1 &scores, Array1 &counts, - int32_t eos, - int32_t min_token, - int32_t max_token, + T eos, + T min_token, + T max_token, int32_t max_order, Array1 *mean, Array1 *var, Array1 *counts_out, Array1 *ngram_order) { - // Outline: // First construct an array of type T which contains values as follows: - // [ tokens.values[-1]+offset, ..., tokens.values[1]+offset, tokens.values[0]+offset, eos+offset, terminator, 0, 0 ] + // [ tokens.values[-1]+offset, ..., tokens.values[1]+offset, tokens.values[0]+offset, eos+offset, terminator, 0, 0, 0 ] // where offset is 1-min_token, and terminator is max_token+1+offset. // The 3 terminating zeros are required by CreateSuffixArray(). // @@ -402,10 +396,150 @@ void GetBestMatchingStatsInternal(Ragged &tokens, // If the total count is nonzero, it should be guaranteed that you never // have to process an interval with zero count; FindTightestNonemptyIntervals() // should guarantee that. + ContextPtr &c = tokens.Context(); + T num_elements = tokens.NumElements(); + K2_CHECK_EQ(mean->Dim(), num_elements); + K2_CHECK_EQ(var->Dim(), num_elements); + K2_CHECK_EQ(counts_out->Dim(), num_elements); + K2_CHECK_EQ(ngram_order->Dim(), num_elements); + + T offset = 1 - min_token, + terminator = max_token + 1 + offset; + // we include the eos and terminator, so plus 2 here. + T seq_len = num_elements + 2; + // 3 zeros are required by CreateSuffixArray3 + Array1 text_array(c, seq_len + 3); + T *text_array_data = text_array.Data(); + const int32_t *tokens_values_data = tokens.values.Data(); + // we want to match the longest common prefix of the word and the words + // preceding it, so we need to reverse the sequence before constructing suffix + // array. + for (T i = 0; i < num_elements; ++i) { + text_array_data[i] = + tokens_values_data[num_elements - i - 1] + offset; + } + T eos_offset = eos + offset; + // fill eos, terminator and required zeros + std::vector tail({eos_offset, terminator, 0, 0, 0}); + for (T i = num_elements; i < text_array.Dim(); ++i) + text_array_data[i] = tail[i - num_elements]; + + Array1 suffix_array(c, seq_len); + CreateSuffixArray(text_array.Data(), seq_len, terminator, + suffix_array.Data()); + + // we need extra one position for `ExclusiveSum` + Array1 reorder_counts(c, seq_len + 1); + Array1 reorder_scores(c, seq_len + 1); + Array1 reorder_scores_squre(c, seq_len + 1); + T *reorder_counts_data = reorder_counts.Data(); + float *reorder_scores_data = reorder_scores.Data(); + double *reorder_scores_squre_data = reorder_scores_squre.Data(); + + const int32_t *counts_data = counts.Data(); + const float *scores_data = scores.Data(); + const T *suffix_array_data = suffix_array.Data(); + for (int32_t i = 0; i < suffix_array.Dim(); ++i) { + // we reverse the sequence above, the order of counts and scores should be + // reversed accordingly, and make ensure that the counts and scores be zero + // in the positions of eos and terminator. + int32_t rindex = seq_len - 2 - suffix_array_data[i] - 1; + reorder_counts_data[i] = rindex < 0 ? 0 : counts_data[rindex]; + reorder_scores_data[i] = rindex < 0 ? 0 : scores_data[rindex]; + reorder_scores_squre_data[i] = rindex < 0 ? 0 : + scores_data[rindex] * scores_data[rindex]; + } + ExclusiveSum(reorder_counts, &reorder_counts); + ExclusiveSum(reorder_scores, &reorder_scores); + ExclusiveSum(reorder_scores_squre, &reorder_scores_squre); + + // total count of all these sentences is zero means that there is no **keys** + // we can not match anything, return as special case. + if (reorder_counts_data[reorder_counts.Dim() - 1] == 0) { + *mean = 0; + *var = 0; + *counts_out = 0; + *ngram_order = 0; + return; + } + Array1 lpc_array(c, seq_len); + CreateLcpArray(text_array.Data(), suffix_array.Data(), seq_len, + lpc_array.Data()); + + Array1 leaf_parent_interval; + Array1 > lcp_intervals; + CreateLcpIntervalArray(c, seq_len, lpc_array.Data(), + &lcp_intervals, &leaf_parent_interval); + + FindTightestNonemptyIntervals(seq_len, &lcp_intervals, + &reorder_counts, &leaf_parent_interval); + const LcpInterval *lcp_intervals_data = lcp_intervals.Data(); + const T *leaf_parent_interval_data = leaf_parent_interval.Data(); + + Array1 dist_to_begin(c, num_elements); + T *dist_to_begin_data = dist_to_begin.Data(); + const int32_t *tokens_row_ids1_data = tokens.RowIds(1).Data(), + *tokens_row_splits1_data = tokens.RowSplits(1).Data(); + K2_EVAL( + c, num_elements, lambda_set_dist_to_begin, (int32_t idx01) + -> void { + int32_t idx0 = tokens_row_ids1_data[idx01], + idx0x = tokens_row_splits1_data[idx0], + idx1 = idx01 - idx0x; + dist_to_begin_data[idx01] = idx1 + 1; + }); + + // mapping original order to suffix array order + Array1 inv_suffix_array(c, seq_len); + T *inv_suffix_array_data = inv_suffix_array.Data(); + for (T i = 0; i < seq_len; i++) { + inv_suffix_array_data[suffix_array_data[i]] = i; + } + float *mean_data = mean->Data(), + *var_data = var->Data(); + int32_t *counts_out_data = counts_out->Data(), + *ngram_order_data = ngram_order->Data(); + + // loop in the original order + for (T i = 0; i < num_elements; ++i) { + // we reverse `tokens.values` above, minus 2 here to remove eos and + // terminator that not belongs to tokens. + T text_array_index = seq_len - 2 - i - 1; + T suffix_index = inv_suffix_array_data[text_array_index]; + // leaf_parent_interval, reorder_counts, reorder_scores are all index with + // suffix array order. + T interval_index = leaf_parent_interval_data[suffix_index]; + LcpInterval interval = lcp_intervals_data[interval_index]; + float scores_sum = reorder_scores_data[interval.rb + 1] - + reorder_scores_data[interval.lb]; + double scores_squre_sum = reorder_scores_squre_data[interval.rb + 1] - + reorder_scores_squre_data[interval.lb]; + int32_t counts_out_interval = reorder_counts_data[interval.rb + 1] - + reorder_counts_data[interval.lb]; + if (interval.lcp == 0) { // tightest interval is root interval + K2_CHECK_EQ(interval.parent, -1); + counts_out_data[i] = 0; + ngram_order_data[i] = 0; + } else { + counts_out_data[i] = counts_out_interval; + ngram_order_data[i] = std::min(interval.lcp, (T)max_order); + // handle the sentence boundary + if (dist_to_begin_data[i] <= interval.lcp) + ngram_order_data[i] = max_order; + } + mean_data[i] = counts_out_interval == 0 ? 0 : + (scores_sum / counts_out_interval); + if (counts_out_interval == 0 || counts_out_interval == 1) { + var_data[i] = 0; + } else { + double numerator = scores_squre_sum - 2 * mean_data[i] * scores_sum + + counts_out_interval * mean_data[i] * mean_data[i]; + var_data[i] = numerator / counts_out_interval; + } + } } - void GetBestMatchingStats(Ragged &tokens, Array1 &scores, Array1 &counts, @@ -417,23 +551,27 @@ void GetBestMatchingStats(Ragged &tokens, Array1 *var, Array1 *counts_out, Array1 *ngram_order) { - ContextPtr c = tokens.Context(); + ContextPtr &c = tokens.Context(); int32_t num_elements = tokens.NumElements(); + K2_CHECK(mean); if (mean->Dim() != num_elements) { *mean = Array1(c, num_elements); } else { K2_CHECK_EQ(mean->Dim(), 0); } + K2_CHECK(var); if (var->Dim() != num_elements) { *var = Array1(c, num_elements); } else { K2_CHECK_EQ(var->Dim(), 0); } + K2_CHECK(counts_out); if (counts_out->Dim() != num_elements) { *counts_out = Array1(c, num_elements); } else { K2_CHECK_EQ(counts_out->Dim(), 0); } + K2_CHECK(ngram_order); if (ngram_order->Dim() != num_elements) { *ngram_order = Array1(c, num_elements); } else { @@ -457,8 +595,8 @@ void GetBestMatchingStats(Ragged &tokens, Array1 this_counts = counts.Arange(begin, end), this_counts_out = counts_out->Arange(begin, end), this_ngram_order = ngram_order->Arange(begin, end); - GetBestMatchingStats(this_tokens, this_scores, this_counts, eos, min_token, - max_token, max_order, + GetBestMatchingStats(this_tokens, this_scores, this_counts, eos, + min_token, max_token, max_order, &this_mean, &this_var, &this_counts_out, &this_ngram_order); } @@ -468,17 +606,16 @@ void GetBestMatchingStats(Ragged &tokens, if (num_elements == 0) { return; // Nothing to do. - } else if (num_elements + 10 < (1 << 15) && (max_token - min_token + 10 < (1 << 15))) { - GetBestMatchingStatsInternal(tokens, scores, counts, eos, min_token, max_token, - max_order, mean, var, counts_out, ngram_order); + } else if (num_elements + 10 < (1 << 15) && + (max_token - min_token + 10 < (1 << 15))) { + GetBestMatchingStatsInternal(tokens, scores, counts, eos, + min_token, max_token, max_order, + mean, var, counts_out, ngram_order); } else { - GetBestMatchingStatsInternal(tokens, scores, counts, eos, min_token, max_token, - max_order, mean, var, counts_out, ngram_order); + GetBestMatchingStatsInternal(tokens, scores, counts, eos, + min_token, max_token, max_order, + mean, var, counts_out, ngram_order); } } - - - - } // namespace k2 diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index 3cb4fde9f..b96d43375 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -159,6 +159,14 @@ struct LcpInterval { // bottom-up; you can treat it as arbitrary. }; +template +std::ostream &operator<<(std::ostream &os, const LcpInterval &interval) { + static constexpr char kSep = ' '; + os << interval.lcp << kSep << interval.lb << kSep << interval.rb << kSep + << interval.parent; + return os; +} + /* Create an array of struct LcpInterval which describes the Lcp intervals @@ -235,12 +243,9 @@ void FindTightestNonemptyIntervals(T seq_len, Array1 *counts_exclusive_sum, Array1 *leaf_parent_intervals); - /* - - For "query" sentences, this function gets the mean and variance of scores - from the best matching words-in-context in a set of of provided "key" + from the best matching words-in-context in a set of provided "key" sentences. This matching process matches the word and the words preceding it, looking for the highest-order match it can find (it's intended for approximating the scores of models that see only left-context, like language @@ -249,7 +254,6 @@ void FindTightestNonemptyIntervals(T seq_len, application is in estimating the scores of hypothesized transcripts, when we have actually computed the scores for only a subset of the hypotheses. - @param [in] tokens A ragged tensor of int32_t with 2 or 3 axes (this function recurses). If 2 axes, this represents a collection of key and query sequences (keys have count==1, query count==0). @@ -265,7 +269,7 @@ void FindTightestNonemptyIntervals(T seq_len, .. where the words would actually be represented as integers, and the eos might be -1. The eos symbol is required if this code is to work as intended (otherwise this code will not - be able to recognize when we have reached the the beginnings + be able to recognize when we have reached the beginnings of sentences when comparing histories). bos symbols are allowed but not required. @@ -323,10 +327,7 @@ void GetBestMatchingStats(Ragged &tokens, int32_t max_order, Array1 *mean, Array1 *var, - Array1 *count, + Array1 *counts_out, Array1 *ngram_order); - - - } #endif // K2_CSRC_NBEST_H_ diff --git a/k2/csrc/nbest_test.cu b/k2/csrc/nbest_test.cu index 8628ac9bd..70a28b365 100644 --- a/k2/csrc/nbest_test.cu +++ b/k2/csrc/nbest_test.cu @@ -1,5 +1,6 @@ /** - * Copyright 2020 Xiaomi Corporation (authors: Haowen Qiu) + * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey + * Wei Kang) * * See LICENSE for clarification regarding multiple authors * @@ -27,6 +28,9 @@ #include #include "k2/csrc/nbest.h" +#include "k2/csrc/ragged.h" +#include "k2/csrc/ragged_ops.h" + namespace k2 { TEST(AlgorithmsTest, TestSuffixArray) { @@ -47,14 +51,14 @@ TEST(AlgorithmsTest, TestSuffixArray) { for (int i = array_len; i < array_len + 3; i++) array_data[i] = 0; - // really array_len + 1, extra elem is to test that it doesn't write past + // really array_len, extra elem is to test that it doesn't write past // the end. - Array1 suffix_array(cpu, array_len + 2); + Array1 suffix_array(cpu, array_len + 1); int32_t *suffix_array_data = suffix_array.Data(); - suffix_array_data[array_len + 1] = -10; // should not be changed. + suffix_array_data[array_len] = -10; // should not be changed. CreateSuffixArray(array_data, array_len, max_symbol, suffix_array_data); - assert(suffix_array_data[array_len + 1] == -10); // should be unchanged. + assert(suffix_array_data[array_len] == -10); // should be unchanged. Array1 seen_indexes(cpu, array_len, 0); int32_t *seen_indexes_data = seen_indexes.Data(); for (int32_t i = 0; i < array_len; i++) @@ -64,7 +68,7 @@ TEST(AlgorithmsTest, TestSuffixArray) { assert(seen_indexes_data[i] == 1); // make sure all integers seen. for (int32_t i = 0; i + 1 < array_len; i++) { int32_t *suffix_a = array_data + suffix_array_data[i], - *suffix_b = array_data + suffix_array_data[i + 1]; + *suffix_b = array_data + suffix_array_data[i + 1]; // checking that each suffix is lexicographically less than the next one. // None are identical, because the terminating zero is always in different // positions. @@ -72,7 +76,7 @@ TEST(AlgorithmsTest, TestSuffixArray) { if (*suffix_a < *suffix_b) break; // correct order assert(!(*suffix_a > *suffix_b)); // order is wrong! - assert(!(suffix_b > array_data + array_len || + assert(!(suffix_a > array_data + array_len || suffix_b > array_data + array_len)); // past array end without correct comparison order. suffix_a++; suffix_b++; @@ -81,7 +85,6 @@ TEST(AlgorithmsTest, TestSuffixArray) { } } - TEST(AlgorithmsTest, TestCreateLcpArray) { ContextPtr cpu = GetCpuContext(); @@ -119,7 +122,6 @@ TEST(AlgorithmsTest, TestCreateLcpArray) { } } - TEST(AlgorithmsTest, TestCreateLcpIntervalArray) { ContextPtr cpu = GetCpuContext(); @@ -203,8 +205,6 @@ TEST(AlgorithmsTest, TestCreateLcpIntervalArray) { } } - - TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { ContextPtr cpu = GetCpuContext(); @@ -243,8 +243,6 @@ TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { array_len, lcp_data, &lcp_intervals, &leaf_parent_intervals); - - // we get one extra don't-care element at the end of `counts_reordered`, // which is required by ExclusiveSum(). Array1 counts_reordered = counts[suffix_array_plusone], @@ -298,8 +296,33 @@ TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { } } - - - +TEST(AlgorithmTest, TestGetBestMatchingStats) { + Ragged tokens(GetCpuContext(), "[ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] " + " [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ]"); + Array1 scores(GetCpuContext(), "[ 1 2 3 4 5 6 7 8 9 10 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 counts(GetCpuContext(), "[ 1 1 1 1 1 1 1 1 1 1 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 mean, var; + Array1 counts_out, ngram_order; + int32_t eos = 8, + min_token = 0, + max_token = 8, + max_order = 2; + GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token, + max_order, &mean, &var, &counts_out, &ngram_order); + Array1 mean_ref(GetCpuContext(), "[ 3.5 2 3 4 5 3.5 7 5.5 6.5 7.5 " + " 3.5 7 5.5 6.5 7.5 5.5 2 3 4 5 ]"); + Array1 var_ref(GetCpuContext(), "[ 6.25 0 0 0 0 6.25 0 6.25 6.25 6.25 " + " 6.25 0 8.25 6.25 6.25 8.25 0 0 0 0 ]"); + Array1 counts_out_ref(GetCpuContext(), "[ 2 1 1 1 1 2 1 2 2 2 " + " 2 1 0 2 2 0 1 1 1 1 ]"); + Array1 ngram_order_ref(GetCpuContext(), "[ 2 1 2 2 2 1 2 1 2 2 " + " 1 2 0 1 2 0 1 2 2 2 ]"); + K2_CHECK(Equal(mean, mean_ref)); + K2_CHECK(Equal(var, var_ref)); + K2_CHECK(Equal(counts_out, counts_out_ref)); + K2_CHECK(Equal(ngram_order, ngram_order_ref)); +} } // namespace k2 From 5093445552837d3b4122d52293f8b1db7f74b57f Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 19 Jul 2021 22:40:24 +0800 Subject: [PATCH 11/16] add GetBestMatchingStats python wrapper --- k2/csrc/nbest.cu | 12 ++-- k2/csrc/nbest_test.cu | 50 ++++++++++++--- k2/python/csrc/torch.cu | 2 + k2/python/csrc/torch/CMakeLists.txt | 1 + k2/python/csrc/torch/discounted_cum_sum.cu | 2 +- k2/python/csrc/torch/nbest.cu | 61 +++++++++++++++++++ k2/python/csrc/torch/nbest.h | 30 +++++++++ k2/python/k2/__init__.py | 4 +- k2/python/k2/utils.py | 11 ++++ k2/python/tests/CMakeLists.txt | 1 + .../tests/get_best_matching_stats_test.py | 60 ++++++++++++++++++ 11 files changed, 217 insertions(+), 17 deletions(-) create mode 100644 k2/python/csrc/torch/nbest.cu create mode 100644 k2/python/csrc/torch/nbest.h create mode 100644 k2/python/tests/get_best_matching_stats_test.py diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index e7aa258ad..bd87be189 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -19,8 +19,8 @@ #include "k2/csrc/nbest.h" -// This is not really a CUDA file but for build-system reasons I'm currently leaving it -// with the .cu extension. +// This is not really a CUDA file but for build-system reasons I'm currently +// leaving it with the .cu extension. namespace k2 { @@ -557,25 +557,25 @@ void GetBestMatchingStats(Ragged &tokens, if (mean->Dim() != num_elements) { *mean = Array1(c, num_elements); } else { - K2_CHECK_EQ(mean->Dim(), 0); + K2_CHECK_EQ(mean->Dim(), num_elements); } K2_CHECK(var); if (var->Dim() != num_elements) { *var = Array1(c, num_elements); } else { - K2_CHECK_EQ(var->Dim(), 0); + K2_CHECK_EQ(var->Dim(), num_elements); } K2_CHECK(counts_out); if (counts_out->Dim() != num_elements) { *counts_out = Array1(c, num_elements); } else { - K2_CHECK_EQ(counts_out->Dim(), 0); + K2_CHECK_EQ(counts_out->Dim(), num_elements); } K2_CHECK(ngram_order); if (ngram_order->Dim() != num_elements) { *ngram_order = Array1(c, num_elements); } else { - K2_CHECK_EQ(ngram_order->Dim(), 0); + K2_CHECK_EQ(ngram_order->Dim(), num_elements); } K2_CHECK(eos >= min_token && eos <= max_token); diff --git a/k2/csrc/nbest_test.cu b/k2/csrc/nbest_test.cu index 70a28b365..5c91fff72 100644 --- a/k2/csrc/nbest_test.cu +++ b/k2/csrc/nbest_test.cu @@ -297,6 +297,29 @@ TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { } TEST(AlgorithmTest, TestGetBestMatchingStats) { + // There are 20 tokens, index with [0, 20) + // keys' positions are [0, 10), queries positions are [10, 20) + // The best matching positions(include the token itself) are as follows + // index 0 : (0, 5, 10) with lcp "48", we add eos(8) + // index 1 : (1, 16,) with lcp "6" + // index 2 : (2, 17,) with lcp "76" + // index 3 : (3, 18,) with lcp "671" + // index 4 : (4, 19,) with lcp "5718" + // index 5 : (5, 10,) with lcp "7184" + // index 6 : (6, 11,) with lcp "43" + // index 7 : (2, 7, 17,) with lcp "7" + // index 8 : (3, 8, 18,) with lcp "71" + // index 9 : (4, 9, 19,) with lcp "718" + // index 10 : (5, 10,) with lcp "7184" + // index 11 : (6, 11,) with lcp "43" + // index 12 : (12,) with no matching + // index 13 : (3, 8, 13, 18,) with lcp "1" + // index 14 : (4, 9, 14, 19,) with lcp "18" + // index 15 : (15,) with no matching + // index 16 : (1, 16,) with lcp "6" + // index 17 : (2, 17,) with lcp "67" + // index 18 : (3, 18,) with lcp "671" + // index 19 : (4, 19,) with lcp "6718" Ragged tokens(GetCpuContext(), "[ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] " " [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ]"); Array1 scores(GetCpuContext(), "[ 1 2 3 4 5 6 7 8 9 10 " @@ -306,23 +329,32 @@ TEST(AlgorithmTest, TestGetBestMatchingStats) { Array1 mean, var; Array1 counts_out, ngram_order; int32_t eos = 8, - min_token = 0, + min_token = 1, max_token = 8, max_order = 2; GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token, max_order, &mean, &var, &counts_out, &ngram_order); - Array1 mean_ref(GetCpuContext(), "[ 3.5 2 3 4 5 3.5 7 5.5 6.5 7.5 " - " 3.5 7 5.5 6.5 7.5 5.5 2 3 4 5 ]"); - Array1 var_ref(GetCpuContext(), "[ 6.25 0 0 0 0 6.25 0 6.25 6.25 6.25 " - " 6.25 0 8.25 6.25 6.25 8.25 0 0 0 0 ]"); - Array1 counts_out_ref(GetCpuContext(), "[ 2 1 1 1 1 2 1 2 2 2 " - " 2 1 0 2 2 0 1 1 1 1 ]"); - Array1 ngram_order_ref(GetCpuContext(), "[ 2 1 2 2 2 1 2 1 2 2 " - " 1 2 0 1 2 0 1 2 2 2 ]"); + Array1 mean_ref(GetCpuContext(), "[ 3.5 2 3 4 5 6 7 5.5 6.5 7.5 " + " 6 7 5.5 6.5 7.5 5.5 2 3 4 5 ]"); + Array1 var_ref(GetCpuContext(), "[ 6.25 0 0 0 0 0 0 6.25 6.25 6.25 " + " 0 0 8.25 6.25 6.25 8.25 0 0 0 0 ]"); + Array1 counts_out_ref(GetCpuContext(), "[ 2 1 1 1 1 1 1 2 2 2 " + " 1 1 0 2 2 0 1 1 1 1 ]"); + Array1 ngram_order_ref(GetCpuContext(), "[ 2 1 2 2 2 2 2 1 2 2 " + " 2 2 0 1 2 0 1 2 2 2 ]"); K2_CHECK(Equal(mean, mean_ref)); K2_CHECK(Equal(var, var_ref)); K2_CHECK(Equal(counts_out, counts_out_ref)); K2_CHECK(Equal(ngram_order, ngram_order_ref)); + + // max_order = 5 + // index 0, 5, 6, 10, 11 match further than the BOS boundary + max_order = 5; + GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token, + max_order, &mean, &var, &counts_out, &ngram_order); + Array1 ngram_order3_ref(GetCpuContext(), "[ 5 1 2 3 4 5 5 1 2 3 " + " 5 5 0 1 2 0 1 2 3 4 ]"); + K2_CHECK(Equal(ngram_order, ngram_order3_ref)); } } // namespace k2 diff --git a/k2/python/csrc/torch.cu b/k2/python/csrc/torch.cu index 0e6851488..46f0aa819 100644 --- a/k2/python/csrc/torch.cu +++ b/k2/python/csrc/torch.cu @@ -30,6 +30,7 @@ #include "k2/python/csrc/torch/fsa_algo.h" #include "k2/python/csrc/torch/index_add.h" #include "k2/python/csrc/torch/index_select.h" +#include "k2/python/csrc/torch/nbest.h" #include "k2/python/csrc/torch/ragged.h" #include "k2/python/csrc/torch/ragged_ops.h" @@ -40,6 +41,7 @@ void PybindTorch(py::module &m) { PybindFsaAlgo(m); PybindIndexAdd(m); PybindIndexSelect(m); + PybindNbest(m); PybindRagged(m); PybindRaggedOps(m); } diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 80cbe5533..9c6a82569 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -6,6 +6,7 @@ set(torch_srcs fsa_algo.cu index_add.cu index_select.cu + nbest.cu ragged.cu ragged_ops.cu torch_util.cu diff --git a/k2/python/csrc/torch/discounted_cum_sum.cu b/k2/python/csrc/torch/discounted_cum_sum.cu index 2aba8135d..51cc9d8bb 100644 --- a/k2/python/csrc/torch/discounted_cum_sum.cu +++ b/k2/python/csrc/torch/discounted_cum_sum.cu @@ -2,7 +2,7 @@ * @brief wraps discounted_cum_sum code. * * @copyright - * Copyright 2010 Xiaomi Corp. (authors: Daniel Povey) + * Copyright 2021 Xiaomi Corp. (authors: Daniel Povey) * * @copyright * See LICENSE for clarification regarding multiple authors diff --git a/k2/python/csrc/torch/nbest.cu b/k2/python/csrc/torch/nbest.cu new file mode 100644 index 000000000..57bd89697 --- /dev/null +++ b/k2/python/csrc/torch/nbest.cu @@ -0,0 +1,61 @@ +/** + * @brief wraps nbest code. + * + * @copyright + * Copyright 2021 Xiaomi Corp. (authors: Wei Kang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/python/csrc/torch/nbest.h" + +#include "k2/csrc/context.h" +#include "k2/csrc/device_guard.h" +#include "k2/csrc/macros.h" +#include "k2/csrc/nbest.h" +#include "k2/csrc/nvtx.h" +#include "k2/csrc/tensor_ops.h" +#include "k2/python/csrc/torch/torch_util.h" + +namespace k2 { + +static void PybindGetBestMatchingStats(py::module &m) { + m.def( + "get_best_matching_stats", + [](Ragged &tokens, torch::Tensor scores, torch::Tensor counts, + int32_t eos, int32_t min_token, int32_t max_token, + int32_t max_order) -> std::tuple { + DeviceGuard guard(tokens.Context()); + Array1 scores_array = FromTorch(scores); + Array1 counts_array = FromTorch(counts); + Array1 mean, var; + Array1 counts_out, ngram_order; + GetBestMatchingStats(tokens, scores_array, counts_array, + eos, min_token, max_token, max_order, + &mean, &var, &counts_out, &ngram_order); + return std::make_tuple(ToTorch(mean), ToTorch(var), + ToTorch(counts_out), ToTorch(ngram_order)); + }, + py::arg("tokens"), py::arg("scores"), py::arg("counts"), py::arg("eos"), + py::arg("min_token"), py::arg("max_token"), py::arg("max_order")); +} + +} // namespace k2 + +void PybindNbest(py::module &m) { + k2::PybindGetBestMatchingStats(m); +} \ No newline at end of file diff --git a/k2/python/csrc/torch/nbest.h b/k2/python/csrc/torch/nbest.h new file mode 100644 index 000000000..a6d5e805e --- /dev/null +++ b/k2/python/csrc/torch/nbest.h @@ -0,0 +1,30 @@ +/** + * @brief python wrapper for nbest.h + * + * @copyright + * Copyright 2021 Xiaomi Corp. (authors: Wei Kang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_PYTHON_CSRC_TORCH_NBEST_H_ +#define K2_PYTHON_CSRC_TORCH_NBEST_H_ + +#include "k2/python/csrc/k2.h" + +void PybindNbest(py::module &m); + +#endif // K2_PYTHON_CSRC_TORCH_NBEST_H_ diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index d7daf6bb8..58fac68b8 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -49,10 +49,12 @@ from .ragged import RaggedShape from .symbol_table import SymbolTable from .utils import create_fsa_vec +from .utils import create_sparse from .utils import is_rand_equivalent +from .utils import get_best_matching_stats from .utils import to_dot from .utils import to_str from .utils import to_str_simple from .utils import to_tensor -from .utils import create_sparse + from _k2.version import with_cuda diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py index b22ed77e6..dc143cabe 100644 --- a/k2/python/k2/utils.py +++ b/k2/python/k2/utils.py @@ -514,3 +514,14 @@ def fsa_from_unary_function_ragged(src: Fsa, dest_arcs: _k2.RaggedArc, k2.autograd_utils.phantom_index_and_sum_scores(dest, src.scores, arc_map) return dest + + +def get_best_matching_stats(tokens: _k2.RaggedInt, scores: torch.Tensor, + counts: torch.Tensor, eos: int, min_token: int, + max_token: int,max_order: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ''' + + ''' + return _k2.get_best_matching_stats(tokens, scores, counts, eos, + min_token, max_token, max_order) \ No newline at end of file diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 8c6e09740..7828e8edb 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -35,6 +35,7 @@ set(py_test_files get_arc_post_test.py get_backward_scores_test.py get_forward_scores_test.py + get_best_matching_stats_test.py get_tot_scores_test.py index_add_test.py index_and_sum_test.py diff --git a/k2/python/tests/get_best_matching_stats_test.py b/k2/python/tests/get_best_matching_stats_test.py new file mode 100644 index 000000000..9430c577e --- /dev/null +++ b/k2/python/tests/get_best_matching_stats_test.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (author: Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R get_best_matching_stats_test_py + +import unittest + +import k2 +import torch + + +class TestGetBestMatchingStats(unittest.TestCase): + + def test(self): + s = '[ [ [ 5 1 4 6 ] [ 5 1 2 6 ] [ 5 3 4 6 ] ] ]' + tokens = k2.RaggedInt(s) + scores = torch.tensor([1, 2, 3, 4, 5, 7, 8, 6, 0, 0, 0, 0], + dtype=torch.float32) + counts = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + dtype=torch.int32) + eos = 6 + min_token = 1 + max_token = 6 + max_order = 2 + mean, var, counts_out, ngram_order = k2.get_best_matching_stats( + tokens, scores, counts, eos, min_token, max_token, max_order) + + mean_ref = torch.tensor([3, 4.5, 3, 4, 3, 4.5, 4.5, 5, 3, 4.5, 3, 4], + dtype=torch.float32) + var_ref = torch.tensor([4, 6.25, 0, 0, 4, 6.25, 5.25, 1, 4, 5.25, 0, 0], + dtype=torch.float32) + counts_out_ref = torch.tensor([2, 2, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1], + dtype=torch.int32) + ngram_order_ref = torch.tensor([2, 2, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2], + dtype=torch.int32) + assert torch.allclose(mean, mean_ref) + assert torch.allclose(var, var_ref) + assert torch.all(torch.eq(counts_out, counts_out_ref)) + assert torch.all(torch.eq(ngram_order, ngram_order_ref)) + + +if __name__ == '__main__': + unittest.main() From dc70241cb3a82f77e8f7ae156a4d2371e0a950d9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 20 Jul 2021 10:18:51 +0800 Subject: [PATCH 12/16] Fix code style & refine unit test --- k2/csrc/nbest.cu | 126 +++++++++++++++++++++++------------------- k2/csrc/nbest.h | 34 ++++++------ k2/csrc/nbest_test.cu | 91 ++++++++++++++++++++++++------ k2/python/k2/utils.py | 87 +++++++++++++++++++++++++++-- 4 files changed, 241 insertions(+), 97 deletions(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index bd87be189..7c1a30e98 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -17,6 +17,7 @@ * limitations under the License. */ +#include #include "k2/csrc/nbest.h" // This is not really a CUDA file but for build-system reasons I'm currently @@ -60,7 +61,7 @@ static void RadixPass(const T* a, T* b, const T* r, T n, T K) { // https://algo2.iti.kit.edu/documents/jacm05-revised.pdf. template void CreateSuffixArray(const T* text, T n, T K, T* SA) { - //assert(text[0] <= text[n-1]); // spot check that termination symbol is + // assert(text[0] <= text[n-1]); // spot check that termination symbol is // larger than other symbols; // <= in case n==1. if (n == 1) { // The paper's code didn't seem to handle n == 1 correctly. @@ -105,18 +106,19 @@ void CreateSuffixArray(const T* text, T n, T K, T* SA) { for (T i = 0; i < n02; i++) SA12[R[i] - 1] = i; //******* Step 2: Sort nonsample suffixes ******** // stably sort the mod 0 suffixes from SA12 by their first character - for (T i = 0, j = 0; i < n02; i++) if (SA12[i] < n0) R0[j++] = 3 * SA12[i]; + for (T i = 0, j = 0; i < n02; i++) + if (SA12[i] < n0) R0[j++] = 3 * SA12[i]; RadixPass(R0, SA0, text, n0, K); //******* Step 3: Merge ******** // merge sorted SA0 suffixes and sorted SA12 suffixes for (T p = 0, t = n0 - n1, k = 0; k < n; k++) { // i is pos of current offset 12 suffix - T i = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2) ; + T i = (SA12[t] < n0 ? SA12[t] * 3 + 1 : (SA12[t] - n0) * 3 + 2); T j = SA0[p]; // pos of current offset 0 suffix if (SA12[t] < n0 ? // different compares for mod 1 and mod 2 suffixes Leq(text[i], R[SA12[t] + n0], text[j], R[j / 3]) : Leq(text[i], text[i + 1], R[SA12[t] - n0 + 1], text[j], - text[j + 1], R[j / 3 + n0])) { // suffix from SA12 is smaller + text[j + 1], R[j / 3 + n0])) { // suffix from SA12 is smaller SA[k] = i; t++; if (t == n02) // done --- only SA0 suffixes left for (k++; p < n0; p++, k++) SA[k] = SA0[p]; @@ -179,7 +181,6 @@ void CreateLcpIntervalArray(ContextPtr c, T *lcp_array, Array1 > *lcp_intervals, Array1 *leaf_parent_intervals) { - *lcp_intervals = Array1 >(c, seq_len); LcpInterval *lcp_intervals_data = lcp_intervals->Data(); @@ -191,13 +192,13 @@ void CreateLcpIntervalArray(ContextPtr c, // This is the stack from Algorithm 1 and Algorithm 2 of // http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf - // (you can refer to the papers mentioned in the documentation in nbest.h if this link goes - // dead). + // (you can refer to the papers mentioned in the documentation in nbest.h + // if this link goes dead). // - // The 'begin', 'last' and 'lcp' members correspond to the 'lb', 'rb' and 'lcp' - // members mentioned there; the 'parent' member is used temporarily on the stack - // to refer to the index of this LcpInterval in `lcp_intervals_data`, i.e. - // it can be interpreted as a 'self' pointer. + // The 'begin', 'last' and 'lcp' members correspond to the 'lb', 'rb' and + // 'lcp' members mentioned there; the 'parent' member is used temporarily + // on the stack to refer to the index of this LcpInterval in + // `lcp_intervals_data`, i.e. it can be interpreted as a 'self' pointer. std::vector > stack; // A separate stack, of leaves of suffix tree; we maintain this so that @@ -206,12 +207,14 @@ void CreateLcpIntervalArray(ContextPtr c, // lcp=0; begin=0; last=undefined; self=0 (interpreting the 'parent' member // as index-of-self - T next = 0; // Will always store the next free index into `lcp_intervals_data` - T dfs_next = 0; // Will always store the next free index into `intervals_order_data`; - // this is an ordering of the indexes into `lcp_intervals_data` - // that corresponds to depth-first search. - T last_interval = -1; // Will store an index into `lcp_intervals`; this comes - // from Algorithm 2 mentioned above + // Will always store the next free index into `lcp_intervals_data` + T next = 0; + // Will always store the next free index into `intervals_order_data`; + // this is an ordering of the indexes into `lcp_intervals_data` that + // corresponds to depth-first search. + T dfs_next = 0; + T last_interval = -1; // Will store an index into `lcp_intervals`; this + // comes from Algorithm 2 mentioned above stack.push_back({0, 0, T(seq_len - 1), next++ }); // We are using a numbering in which the terminating symbol $ is included // in the array length, which is why we do "i < seq_len" and not @@ -234,18 +237,21 @@ void CreateLcpIntervalArray(ContextPtr c, } // process(last_interval): lcp_intervals_data[last_interval_dfsorder] = stack.back(); - // Previously tried doing: + // Previously tried doing: // stack.back().rb = i - 1; - // a bit further above, but hit some kind of compiler problem, the assignment - // had no effect (back() is supposed to return a reference). + // a bit further above, but hit some kind of compiler problem, + // the assignment had no effect (back() is supposed to return a + // reference). lcp_intervals_data[last_interval_dfsorder].rb = i - 1; intervals_order_data[last_interval] = last_interval_dfsorder; stack.pop_back(); if (lcp_array_i <= stack.back().lcp) { - // lcp_intervals_data[last_interval_dfsorder].parent represents the parent - // of `last_interval`; `stack.back().parent` currently represents - // the intended position of stack.back() itself, not of its parent. - lcp_intervals_data[last_interval_dfsorder].parent = stack.back().parent; + // lcp_intervals_data[last_interval_dfsorder].parent represents + // the parent of `last_interval`; `stack.back().parent` currently + // represents the intended position of stack.back() itself, + // not of its parent. + lcp_intervals_data[last_interval_dfsorder].parent = + stack.back().parent; last_interval = -1; } } @@ -310,10 +316,10 @@ void FindTightestNonemptyIntervals(T seq_len, const LcpInterval *lcp_intervals_data = lcp_intervals->Data(); const T *counts_exclusive_sum_data = counts_exclusive_sum->Data(); int32_t num_intervals = lcp_intervals->Dim(); - // `tightest_nonempty_intervals` gives, for each interval 0 <= i < num_intervals, - // the index j >= i of the tightest enclosing interval that has a nonzero - // count. As a special case, if all counts are zero, it will return the - // top (last) interval. + // `tightest_nonempty_intervals` gives, for each interval + // 0 <= i < num_intervals, the index j >= i of the tightest enclosing + // interval that has a nonzero count. As a special case, if all counts + // are zero, it will return the top (last) interval. Array1 tightest_nonempty_interval(c, num_intervals); T *tightest_nonempty_interval_data = tightest_nonempty_interval.Data(); for (T i = num_intervals - 1; i >= 0; --i) { @@ -324,8 +330,8 @@ void FindTightestNonemptyIntervals(T seq_len, counts_exclusive_sum_data[cur_interval.lb]) { j = i; } else { - // j > i, we will have already set tightest_nonempty_interval_data at this - // location. + // j > i, we will have already set tightest_nonempty_interval_data + // at this location. j = tightest_nonempty_interval_data[cur_interval.parent]; } tightest_nonempty_interval_data[i] = j; @@ -339,12 +345,12 @@ void FindTightestNonemptyIntervals(T seq_len, // Instantiate template template void FindTightestNonemptyIntervals(int32_t seq_len, - Array1 > *lcp_intervals, + Array1 > *lcp_intervals, // NOLINT Array1 *counts_exclusive_sum, Array1 *leaf_parent_intervals); template void FindTightestNonemptyIntervals(int16_t seq_len, - Array1 > *lcp_intervals, + Array1 > *lcp_intervals, // NOLINT Array1 *counts_exclusive_sum, Array1 *leaf_parent_intervals); @@ -366,36 +372,42 @@ void GetBestMatchingStatsInternal(Ragged &tokens, Array1 *ngram_order) { // Outline: // First construct an array of type T which contains values as follows: - // [ tokens.values[-1]+offset, ..., tokens.values[1]+offset, tokens.values[0]+offset, eos+offset, terminator, 0, 0, 0 ] - // where offset is 1-min_token, and terminator is max_token+1+offset. - // The 3 terminating zeros are required by CreateSuffixArray(). + // [ tokens.values[-1]+offset, ..., tokens.values[1]+offset, + // tokens.values[0]+offset, eos+offset, terminator, 0, 0, 0 ] + // where offset is 1-min_token, and terminator is max_token+1+offset. + // The 3 terminating zeros are required by CreateSuffixArray(). // - // Create the reordered counts array `counts_reordered`, in the same order as - // the suffix array, then its exclusive sum, e.g. `counts_reordered_excsum`. - // At this point we can also create similar reordered exclusive-sums of `scores` - // and scores-squared; do these as double or roundoff will be a problem. + // Call CreateSuffixArray (seq_len == tokens.Dim() + 2, we include the + // eos and terminator). + // + // Create the reordered counts array `counts_reordered`, in the same order + // as the suffix array, then its exclusive sum, + // e.g. `counts_reordered_excsum`. At this point we can also create similar + // reordered exclusive-sums of `scores` and scores-squared; + // do these as double or roundoff will be a problem. + // + // Call CreateLcpArray, CreateLcpIntervalArray, + // FindTightestNonemptyIntervals // - // Call CreateSuffixArray (seq_len == tokens.Dim() + 2, we include the eos and terminator). - // Call CreateLcpArray, CreateLcpIntervalArray, FindTightestNonemptyIntervals() - // By this point we should have enough information to directly create the - // outputs : mean, var, counts_out, ngram_order. We need to be a bit careful - // about ngram_order at positions when the suffix goes up to the next eos - // (i.e. it goes to the beginning of the sentence) because the correct ngram order - // to output here is `max_order`. You will have to create an array containing - // the distance from the beginning of the sentence (can be constructed from - // the row_ids and row_splits of `tokens`), and index it with the suffix array - // to get it in the right order. - // Note: we only really care about the output at the query positions, but try - // to make it so you don't need to treat keys as a special case. + // outputs : mean, var, counts_out, ngram_order. We need to be a bit + // careful about ngram_order at positions when the suffix goes up to the + // next eos (i.e. it goes to the beginning of the sentence) because the + // correct ngram order to output here is `max_order`. You will have to + // create an array containing the distance from the beginning of the + // sentence (can be constructed from the row_ids and row_splits of `tokens`) + // + // Note: we only really care about the output at the query positions, but + // try to make it so you don't need to treat keys as a special case. // // Special cases/conditions to consider include: - // - make sure the `count` in the position of the eos and terminator are zero + // - make sure the `count` in the position of the eos and terminator + // are zero // - various code may break if the total count over all these sentences is // zero, so you could just detect that and treat it as a special case. // If the total count is nonzero, it should be guaranteed that you never - // have to process an interval with zero count; FindTightestNonemptyIntervals() - // should guarantee that. + // have to process an interval with zero count; + // FindTightestNonemptyIntervals() should guarantee that. ContextPtr &c = tokens.Context(); T num_elements = tokens.NumElements(); K2_CHECK_EQ(mean->Dim(), num_elements); @@ -412,8 +424,8 @@ void GetBestMatchingStatsInternal(Ragged &tokens, T *text_array_data = text_array.Data(); const int32_t *tokens_values_data = tokens.values.Data(); // we want to match the longest common prefix of the word and the words - // preceding it, so we need to reverse the sequence before constructing suffix - // array. + // preceding it, so we need to reverse the sequence before constructing + // suffix array. for (T i = 0; i < num_elements; ++i) { text_array_data[i] = tokens_values_data[num_elements - i - 1] + offset; @@ -523,7 +535,7 @@ void GetBestMatchingStatsInternal(Ragged &tokens, ngram_order_data[i] = 0; } else { counts_out_data[i] = counts_out_interval; - ngram_order_data[i] = std::min(interval.lcp, (T)max_order); + ngram_order_data[i] = min(interval.lcp, (T)max_order); // handle the sentence boundary if (dist_to_begin_data[i] <= interval.lcp) ngram_order_data[i] = max_order; diff --git a/k2/csrc/nbest.h b/k2/csrc/nbest.h index b96d43375..acd34eb19 100644 --- a/k2/csrc/nbest.h +++ b/k2/csrc/nbest.h @@ -33,21 +33,19 @@ namespace k2 { // This header contains certain utility functions that are used in rescoring // n-best lists: specifically, functions that help us select which among a set -// of n-best candidates to select for rescoring. The selection scheme is a little -// complex. It is intended to be used in a context where we do multiple successive -// rounds of n-best list rescoring, and we use the results of the 1st round -// to guide selection of candidates in the second round. So for each word -// in each n-best path that we are considering, we find the best-matching -// positions among those that we evaluated in the first round and we use those -// as inputs to a model that predicts the scores of words after n-best-list -// rescoring. +// of n-best candidates to select for rescoring. The selection scheme is a +// little complex. It is intended to be used in a context where we do multiple +// successive rounds of n-best list rescoring, and we use the results of the +// 1st round to guide selection of candidates in the second round. +// So for each word in each n-best path that we are considering, we find the +// best-matching positions among those that we evaluated in the first round and +// we use those as inputs to a model that predicts the scores of words after +// n-best-list rescoring. // // Some of these functions may seem a little unrelated to n-best lists, they // are algorithms involving suffix arrays, which we use internally in some // algorithms we use to process n-best lists. - - /* This function creates a suffix array; it is based on the code in @@ -143,8 +141,6 @@ void CreateLcpArray(const T *text_array, Type LcpInterval is used to store information about the lcp interval, which we'll later use in algorithms that traverse the suffix tree. - - */ template struct LcpInterval { @@ -153,10 +149,11 @@ struct LcpInterval { // longest prefix shared by all suffixes in this interval. T lb; // Index of the first element (left boundary) T rb; // Index of the last elemen (right boundary) - T parent; // The parent of this lcp-interval (-1 if this is the top interval), - // in the order in which it appears in this array (of - // lcp-intervals). Note: this order is neither top-down or - // bottom-up; you can treat it as arbitrary. + T parent; // The parent of this lcp-interval + // (-1 if this is the top interval), + // in the order in which it appears in this array (of + // lcp-intervals). Note: this order is neither top-down or + // bottom-up; you can treat it as arbitrary. }; template @@ -217,7 +214,8 @@ void CreateLcpIntervalArray(ContextPtr c, Template args: T is a signed type, intended to be int16_t or int32_t - @param [in] seq_len The length of the sequence, including the terminating $. + @param [in] seq_len The length of the sequence, + including the terminating $. @param [in] lcp_intervals The array of lcp intervals, as returned by CreateLcpIntervalArray @param [in] counts_exclusive_sum The exclusive-sum of counts of symbols in @@ -329,5 +327,5 @@ void GetBestMatchingStats(Ragged &tokens, Array1 *var, Array1 *counts_out, Array1 *ngram_order); -} +} // namespace k2 #endif // K2_CSRC_NBEST_H_ diff --git a/k2/csrc/nbest_test.cu b/k2/csrc/nbest_test.cu index 5c91fff72..e67446b48 100644 --- a/k2/csrc/nbest_test.cu +++ b/k2/csrc/nbest_test.cu @@ -43,9 +43,9 @@ TEST(AlgorithmsTest, TestSuffixArray) { Array1 array(cpu, array_len + 3); int32_t *array_data = array.Data(); for (int i = 0; i + 1 < array_len; i++) - array_data[i] = RandInt(1, max_symbol - 1); // termination symbol must be - // larger than all others, - // don't allow + array_data[i] = RandInt(1, max_symbol - 1); // termination symbol must + // be larger than all + // others, don't allow array_data[array_len - 1] = max_symbol; // Termination symbol for (int i = array_len; i < array_len + 3; i++) @@ -74,10 +74,11 @@ TEST(AlgorithmsTest, TestSuffixArray) { // positions. while (true) { if (*suffix_a < *suffix_b) - break; // correct order + break; // correct order assert(!(*suffix_a > *suffix_b)); // order is wrong! + // past array end without correct comparison order. assert(!(suffix_a > array_data + array_len || - suffix_b > array_data + array_len)); // past array end without correct comparison order. + suffix_b > array_data + array_len)); suffix_a++; suffix_b++; } @@ -163,7 +164,8 @@ TEST(AlgorithmsTest, TestCreateLcpIntervalArray) { assert(lcp_interval >= 0 && lcp_interval < num_intervals); assert(i >= lcp_intervals_data[lcp_interval].lb && i <= lcp_intervals_data[lcp_interval].rb); - int32_t lcp = lcp_intervals_data[lcp_interval].lcp; // the lcp value / height + // the lcp value / height + int32_t lcp = lcp_intervals_data[lcp_interval].lcp; for (int32_t j = 0; j < num_intervals; j++) { // The interval that i is a member of should be the tightest enclosing @@ -191,7 +193,8 @@ TEST(AlgorithmsTest, TestCreateLcpIntervalArray) { interval.rb < array_len); assert(lcp_data[interval.lb] < interval.lcp || (interval.lb == 0 && interval.lcp == 0)); - assert(interval.rb == array_len - 1 || lcp_data[interval.rb + 1] < interval.lcp); + assert(interval.rb == array_len - 1 || + lcp_data[interval.rb + 1] < interval.lcp); if (array_len != 1) { int32_t min_lcp = 1000000; for (int32_t j = interval.lb + 1; j <= interval.rb; ++j) @@ -266,7 +269,8 @@ TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { int32_t lcp_interval = leaf_parent_intervals_data[i], nonempty_lcp_interval = leaf_parent_intervals_mod_data[i]; assert(lcp_interval >= 0 && lcp_interval < num_intervals); - assert(nonempty_lcp_interval >= 0 && nonempty_lcp_interval < num_intervals); + assert(nonempty_lcp_interval >= 0 && + nonempty_lcp_interval < num_intervals); if (counts_reordered_sum[array_len] == 0) { // If the total count is zero, everything should go to the top of the // tree, but we won't otherwise test this. @@ -274,8 +278,8 @@ TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { } else { int32_t lcp = lcp_intervals_data[lcp_interval].lcp; K2_CHECK_EQ((lcp_interval == nonempty_lcp_interval), - (counts_reordered_sum[lcp_intervals_data[lcp_interval].lb] != - counts_reordered_sum[lcp_intervals_data[lcp_interval].rb + 1])); + (counts_reordered_sum[lcp_intervals_data[lcp_interval].lb] != // NOLINT + counts_reordered_sum[lcp_intervals_data[lcp_interval].rb + 1])); // NOLINT K2_CHECK(i >= lcp_intervals_data[nonempty_lcp_interval].lb && i <= lcp_intervals_data[nonempty_lcp_interval].rb); @@ -296,11 +300,30 @@ TEST(AlgorithmsTest, TestFindTightestNonemptyIntervals) { } } -TEST(AlgorithmTest, TestGetBestMatchingStats) { +TEST(AlgorithmTest, TestGetBestMatchingStatsEmpty) { + Ragged tokens(GetCpuContext(), "[ [ [ ] ] ]"); + Array1 scores(GetCpuContext(), "[ ]"); + Array1 counts(GetCpuContext(), "[ ]"); + Array1 mean, var; + Array1 counts_out, ngram_order; + int32_t eos = 8, + min_token = 1, + max_token = 8, + max_order = 2; + GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token, + max_order, &mean, &var, &counts_out, &ngram_order); + + K2_CHECK_EQ(mean.Dim(), 0); + K2_CHECK_EQ(var.Dim(), 0); + K2_CHECK_EQ(counts_out.Dim(), 0); + K2_CHECK_EQ(ngram_order.Dim(), 0); +} + +TEST(AlgorithmTest, TestGetBestMatchingStatsSingle) { // There are 20 tokens, index with [0, 20) // keys' positions are [0, 10), queries positions are [10, 20) // The best matching positions(include the token itself) are as follows - // index 0 : (0, 5, 10) with lcp "48", we add eos(8) + // index 0 : (0, 5, 10) with lcp "84", we add eos(8) // index 1 : (1, 16,) with lcp "6" // index 2 : (2, 17,) with lcp "76" // index 3 : (3, 18,) with lcp "671" @@ -346,15 +369,47 @@ TEST(AlgorithmTest, TestGetBestMatchingStats) { K2_CHECK(Equal(var, var_ref)); K2_CHECK(Equal(counts_out, counts_out_ref)); K2_CHECK(Equal(ngram_order, ngram_order_ref)); +} - // max_order = 5 - // index 0, 5, 6, 10, 11 match further than the BOS boundary - max_order = 5; +TEST(AlgorithmTest, TestGetBestMatchingStatsSingleMulti) { + Ragged tokens(GetCpuContext(), "[ [ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] " + " [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ] " + " [ [ 5 1 4 8 ] [ 5 1 2 8 ] " + " [ 5 3 4 8 ] ] ]"); + Array1 scores(GetCpuContext(), "[ 1 2 3 4 5 6 7 8 9 10 " + " 0 0 0 0 0 0 0 0 0 0 " + " 1 2 3 4 5 7 8 6 0 0 0 0 ]"); + Array1 counts(GetCpuContext(), "[ 1 1 1 1 1 1 1 1 1 1 " + " 0 0 0 0 0 0 0 0 0 0 " + " 1 1 1 1 1 1 1 1 0 0 0 0 ]"); + Array1 mean, var; + Array1 counts_out, ngram_order; + int32_t eos = 8, + min_token = 0, + max_token = 10, + max_order = 5; GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token, max_order, &mean, &var, &counts_out, &ngram_order); - Array1 ngram_order3_ref(GetCpuContext(), "[ 5 1 2 3 4 5 5 1 2 3 " - " 5 5 0 1 2 0 1 2 3 4 ]"); - K2_CHECK(Equal(ngram_order, ngram_order3_ref)); + Array1 mean_ref(GetCpuContext(), "[ 3.5 2 3 4 5 6 7 5.5 6.5 7.5 " + " 6 7 5.5 6.5 7.5 5.5 2 3 4 5 " + " 3 4.5 3 4 3 4.5 4.5 5 " + " 3 4.5 3 4 ]"); + Array1 var_ref(GetCpuContext(), "[ 6.25 0 0 0 0 0 0 6.25 6.25 6.25 " + " 0 0 8.25 6.25 6.25 8.25 0 0 0 0 " + " 4 6.25 0 0 4 6.25 5.25 1 " + " 4 5.25 0 0 ]"); + Array1 counts_out_ref(GetCpuContext(), "[ 2 1 1 1 1 1 1 2 2 2 " + " 1 1 0 2 2 0 1 1 1 1 " + " 2 2 1 1 2 2 0 2 " + " 2 0 1 1 ]"); + Array1 ngram_order_ref(GetCpuContext(), "[ 5 1 2 3 4 5 5 1 2 3 " + " 5 5 0 1 2 0 1 2 3 4 " + " 5 5 1 2 5 5 0 1 " + " 5 0 1 2 ]"); + K2_CHECK(Equal(mean, mean_ref)); + K2_CHECK(Equal(var, var_ref)); + K2_CHECK(Equal(counts_out, counts_out_ref)); + K2_CHECK(Equal(ngram_order, ngram_order_ref)); } } // namespace k2 diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py index c9521f0f3..7e1869f0e 100644 --- a/k2/python/k2/utils.py +++ b/k2/python/k2/utils.py @@ -660,10 +660,89 @@ def random_fsa_vec(min_num_fsas: int = 1, def get_best_matching_stats(tokens: _k2.RaggedInt, scores: torch.Tensor, counts: torch.Tensor, eos: int, min_token: int, - max_token: int,max_order: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - ''' + max_token: int, max_order: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # noqa + '''For "query" sentences, this function gets the mean and variance of + scores from the best matching words-in-context in a set of provided "key" + sentences. This matching process matches the word and the words preceding + it, looking for the highest-order match it can find (it's intended for + approximating the scores of models that see only left-context, + like language models). The intended application is in estimating the scores + of hypothesized transcripts, when we have actually computed the scores for + only a subset of the hypotheses. + + CAUTION: + This function only runs on CPU for now. + + Args: + tokens: + A ragged tensor of int32_t with 2 or 3 axes. If 2 axes, this represents + a collection of key and query sequences. If 3 axes, this represents a + set of such collections. + + 2-axis example: + [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ] + 3-axis example: + [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ], + [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ] + + where the words would actually be represented as integers, + The eos symbol is required if this code is to work as intended + (otherwise this code will not be able to recognize when we have reached + the beginnings of sentences when comparing histories). + bos symbols are allowed but not required. + + scores: + An one dim torch.tensor with scores.size() == tokens.NumElements(), + this is the item for which we are requesting best-matching values + (as means and variances in case there are multiple best matches). + In our anticipated use, these would represent scores of words in the + sentences, but they could represent anything. + counts: + An one dim torch.tensor with counts.size() == tokens.NumElements(), + containing 1 for words that are considered "keys" and 0 for + words that are considered "queries". Typically some entire + sentences will be keys and others will be queries. + eos: + The value of the eos (end of sentence) symbol; internally, this + is used as an extra padding value before the first sentence in each + collection, so that it can act like a "bos" symbol. + min_token: + The lowest possible token value, including the bos + symbol (e.g., might be -1). + max_token: + The maximum possible token value. Be careful not to + set this too large the implementation contains a part which + takes time and space O(max_token - min_token). + max_order: + The maximum n-gram order to ever return in the + `ngram_order` output; the output will be the minimum of max_order + and the actual order matched; or max_order if we matched all the + way to the beginning of both sentences. The main reason this is + needed is that we need a finite number to return at the + beginning of sentences. + Returns: + Returns a tuple of four torch.tensor (mean, var, counts_out, ngram_order) + mean: + For query positions, will contain the mean of the scores at the + best matching key positions, or zero if that is undefined because + there are no key positions at all. For key positions, + you can treat the output as being undefined (actually they + are treated the same as queries, but won't match with only + themselves because we don't match at singleton intervals). + var: + Like `mean`, but contains the (centered) variance + of the best matching positions. + counts_out: + The number of key positions that contributed to the `mean` + and `var` statistics. This should only be zero if `counts` + was all zero. + ngram_order: + The n-gram order corresponding to the best matching + positions found at each query position, up to a maximum of + `max_order`; will be `max_order` if we matched all + the way to the beginning of a sentence. ''' return _k2.get_best_matching_stats(tokens, scores, counts, eos, - min_token, max_token, max_order) \ No newline at end of file + min_token, max_token, max_order) From 4455eece4464dc1171b5a4d9a335b28a64b2edaa Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 20 Jul 2021 10:27:10 +0800 Subject: [PATCH 13/16] remove trailing whitespace --- k2/python/k2/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py index 7e1869f0e..d6abb3d9d 100644 --- a/k2/python/k2/utils.py +++ b/k2/python/k2/utils.py @@ -672,7 +672,7 @@ def get_best_matching_stats(tokens: _k2.RaggedInt, scores: torch.Tensor, only a subset of the hypotheses. CAUTION: - This function only runs on CPU for now. + This function only runs on CPU for now. Args: tokens: @@ -718,7 +718,7 @@ def get_best_matching_stats(tokens: _k2.RaggedInt, scores: torch.Tensor, The maximum n-gram order to ever return in the `ngram_order` output; the output will be the minimum of max_order and the actual order matched; or max_order if we matched all the - way to the beginning of both sentences. The main reason this is + way to the beginning of both sentences. The main reason this is needed is that we need a finite number to return at the beginning of sentences. From c9edaaa4cbe58be92f2d06498934ea6a6ccc3b5a Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 20 Jul 2021 10:50:29 +0800 Subject: [PATCH 14/16] Add test case that no keys in tokens --- k2/csrc/nbest.cu | 4 +++- k2/csrc/nbest_test.cu | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/k2/csrc/nbest.cu b/k2/csrc/nbest.cu index 7c1a30e98..fd33281ae 100644 --- a/k2/csrc/nbest.cu +++ b/k2/csrc/nbest.cu @@ -453,7 +453,7 @@ void GetBestMatchingStatsInternal(Ragged &tokens, const T *suffix_array_data = suffix_array.Data(); for (int32_t i = 0; i < suffix_array.Dim(); ++i) { // we reverse the sequence above, the order of counts and scores should be - // reversed accordingly, and make ensure that the counts and scores be zero + // reversed accordingly, and make sure that the counts and scores be zero // in the positions of eos and terminator. int32_t rindex = seq_len - 2 - suffix_array_data[i] - 1; reorder_counts_data[i] = rindex < 0 ? 0 : counts_data[rindex]; @@ -564,6 +564,8 @@ void GetBestMatchingStats(Ragged &tokens, Array1 *counts_out, Array1 *ngram_order) { ContextPtr &c = tokens.Context(); + K2_CHECK_EQ(c->GetDeviceType(), kCpu); + int32_t num_elements = tokens.NumElements(); K2_CHECK(mean); if (mean->Dim() != num_elements) { diff --git a/k2/csrc/nbest_test.cu b/k2/csrc/nbest_test.cu index e67446b48..8bc4600e3 100644 --- a/k2/csrc/nbest_test.cu +++ b/k2/csrc/nbest_test.cu @@ -371,6 +371,35 @@ TEST(AlgorithmTest, TestGetBestMatchingStatsSingle) { K2_CHECK(Equal(ngram_order, ngram_order_ref)); } +TEST(AlgorithmTest, TestGetBestMatchingStatsSpecial) { + Ragged tokens(GetCpuContext(), "[ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] " + " [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ]"); + Array1 scores(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 counts(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 mean, var; + Array1 counts_out, ngram_order; + int32_t eos = 8, + min_token = 1, + max_token = 8, + max_order = 2; + GetBestMatchingStats(tokens, scores, counts, eos, min_token, max_token, + max_order, &mean, &var, &counts_out, &ngram_order); + Array1 mean_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 var_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 counts_out_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + Array1 ngram_order_ref(GetCpuContext(), "[ 0 0 0 0 0 0 0 0 0 0 " + " 0 0 0 0 0 0 0 0 0 0 ]"); + K2_CHECK(Equal(mean, mean_ref)); + K2_CHECK(Equal(var, var_ref)); + K2_CHECK(Equal(counts_out, counts_out_ref)); + K2_CHECK(Equal(ngram_order, ngram_order_ref)); +} + TEST(AlgorithmTest, TestGetBestMatchingStatsSingleMulti) { Ragged tokens(GetCpuContext(), "[ [ [ 4 6 7 1 8 ] [ 4 3 7 1 8 ] " " [ 4 3 2 1 8 ] [ 5 6 7 1 8 ] ] " From 6d371d810719bf35136138cd0d32ba030b67688b Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 20 Jul 2021 17:25:48 +0800 Subject: [PATCH 15/16] Fix code style --- k2/python/csrc/torch/nbest.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/k2/python/csrc/torch/nbest.cu b/k2/python/csrc/torch/nbest.cu index 57bd89697..f4d987ea1 100644 --- a/k2/python/csrc/torch/nbest.cu +++ b/k2/python/csrc/torch/nbest.cu @@ -20,7 +20,7 @@ * limitations under the License. */ -#include "k2/python/csrc/torch/nbest.h" +#include #include "k2/csrc/context.h" #include "k2/csrc/device_guard.h" @@ -28,6 +28,7 @@ #include "k2/csrc/nbest.h" #include "k2/csrc/nvtx.h" #include "k2/csrc/tensor_ops.h" +#include "k2/python/csrc/torch/nbest.h" #include "k2/python/csrc/torch/torch_util.h" namespace k2 { @@ -37,8 +38,8 @@ static void PybindGetBestMatchingStats(py::module &m) { "get_best_matching_stats", [](Ragged &tokens, torch::Tensor scores, torch::Tensor counts, int32_t eos, int32_t min_token, int32_t max_token, - int32_t max_order) -> std::tuple { + int32_t max_order) -> tuple { DeviceGuard guard(tokens.Context()); Array1 scores_array = FromTorch(scores); Array1 counts_array = FromTorch(counts); @@ -58,4 +59,4 @@ static void PybindGetBestMatchingStats(py::module &m) { void PybindNbest(py::module &m) { k2::PybindGetBestMatchingStats(m); -} \ No newline at end of file +} From c04e7a878faf8e0f940e694ebbb2d9659a796844 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 23 Jul 2021 15:18:46 +0800 Subject: [PATCH 16/16] wrap ragged.cat for float; fix review comments --- k2/csrc/CMakeLists.txt | 2 +- k2/python/csrc/torch/nbest.cu | 4 ++-- k2/python/csrc/torch/ragged_ops.cu | 1 + k2/python/k2/utils.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index db48d8b91..a97e291e8 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -88,7 +88,7 @@ else() endif() # the target -add_library(context ${context_srcs} ${context_cc_srcs}) +add_library(context ${context_srcs}) target_compile_definitions(context PUBLIC K2_TORCH_VERSION_MAJOR=${K2_TORCH_VERSION_MAJOR}) target_compile_definitions(context PUBLIC K2_TORCH_VERSION_MINOR=${K2_TORCH_VERSION_MINOR}) diff --git a/k2/python/csrc/torch/nbest.cu b/k2/python/csrc/torch/nbest.cu index f4d987ea1..ad2b7f094 100644 --- a/k2/python/csrc/torch/nbest.cu +++ b/k2/python/csrc/torch/nbest.cu @@ -38,8 +38,8 @@ static void PybindGetBestMatchingStats(py::module &m) { "get_best_matching_stats", [](Ragged &tokens, torch::Tensor scores, torch::Tensor counts, int32_t eos, int32_t min_token, int32_t max_token, - int32_t max_order) -> tuple { + int32_t max_order) -> std::tuple { DeviceGuard guard(tokens.Context()); Array1 scores_array = FromTorch(scores); Array1 counts_array = FromTorch(counts); diff --git a/k2/python/csrc/torch/ragged_ops.cu b/k2/python/csrc/torch/ragged_ops.cu index 61a18c2e1..45496fba4 100644 --- a/k2/python/csrc/torch/ragged_ops.cu +++ b/k2/python/csrc/torch/ragged_ops.cu @@ -411,6 +411,7 @@ void PybindRaggedOps(py::module &m) { PybindArgMaxPerSublist(m); PybindArgMaxPerSublist(m); PybindCat(m); + PybindCat(m); PybindCat(m); PybindCreateRagged2(m); PybindCreateRagged2(m); diff --git a/k2/python/k2/utils.py b/k2/python/k2/utils.py index d6abb3d9d..7b5bb9a24 100644 --- a/k2/python/k2/utils.py +++ b/k2/python/k2/utils.py @@ -693,7 +693,7 @@ def get_best_matching_stats(tokens: _k2.RaggedInt, scores: torch.Tensor, bos symbols are allowed but not required. scores: - An one dim torch.tensor with scores.size() == tokens.NumElements(), + A one dim torch.tensor with scores.size() == tokens.NumElements(), this is the item for which we are requesting best-matching values (as means and variances in case there are multiple best matches). In our anticipated use, these would represent scores of words in the