
Plan for multi pass n-best rescoring #232

danpovey opened this issue Jul 11, 2021 · 7 comments

@danpovey

[Guys, I have gym now so I'll submit this and write the rest of this later today. ]

I am creating an issue to describe a plan for multi-pass n-best-list rescoring. This will also require
new code in k2; I'll create a separate issue for that.
The scenario is that we have a CTC or LF-MMI model and we do the 1st decoding pass with it.
Anything that we can do with lattices, we do first (e.g. any FST-based LM rescoring).
Let the possibly-LM-rescored lattice be the starting point for the n-best rescoring process.

The first step is to generate a long n-best list for each lattice by calling RandomPaths() with a largish number,
like 1000. We then choose unique paths based on token sequences, where 'token' is whatever type of token
we are using in the transformer and RNNLM -- probably word pieces. That is, we use inner_labels='tokens'
when composing with the CTC topo to make the decoding graph; these labels get propagated
to the lattices, so we can take lats.tokens, remove epsilons and pick the unique paths.
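
A rough sketch of how that first step might look; the helper names here (k2.random_paths, k2.index, k2.ragged.remove_values_leq, k2.ragged.unique_sequences) follow the k2 Python API of around this time and are assumptions that may need adjusting:

import k2


def unique_token_paths(lats: k2.Fsa, num_paths: int = 1000) -> k2.RaggedInt:
    """Sample up to `num_paths` paths per lattice and keep one representative
    per unique token sequence.  Assumes `lats` has a `tokens` attribute
    (propagated from inner_labels='tokens' in the decoding graph)."""
    # paths: ragged with axes [utt][path][arc], containing arc indexes.
    paths = k2.random_paths(lats, use_double_scores=True, num_paths=num_paths)

    # Look up the token on each arc of each sampled path.
    token_seqs = k2.index(lats.tokens, paths)

    # Remove epsilons (0) and the final -1's so comparison is purely on tokens.
    token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)

    # Within each utterance, keep only unique token sequences.
    unique_seqs, _, _ = k2.ragged.unique_sequences(
        token_seqs, need_num_repeats=True, need_new2old_indexes=True)
    return unique_seqs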

I think we could have a data structure called Nbest -- we could draft this in snowfall for now and later move it
to k2 -- that contains an Fsa and also a _k2.RaggedShape that dictates how each of the paths relates to the
original supervision segments. But I guess we could draft this pipeline without the data structure.

Supposing we have the Nbest with ragged numbers of paths, we can then add epsilon self-loops and
intersect it with the lattices, after moving the 'tokens' to the 'labels' of the lattices; we'd then
get the 1-best path and remove epsilons so that we get an Nbest that has just the best path's
tokens and no epsilons.
(We could define, in class Nbest, a form of intersect() that does the right thing when composing with an Fsa
representing an FsaVec; we might also define wrappers for some Fsa operations so they work also on Nbest).
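
For instance, Nbest.intersect() could be a thin wrapper along these lines; this is only a sketch, assuming the Nbest draft above, lattices that are arc-sorted with 'tokens' moved onto 'labels', and the k2.intersect_device / k2.shortest_path behavior of the k2 of this period:

def intersect_with_lats(nbest: Nbest, lats: k2.Fsa) -> Nbest:
    """Intersect each path in `nbest` with the lattice of the supervision it
    came from, take the best path and remove epsilons, so the result is again
    an Nbest of linear token sequences (now carrying the lattice scores)."""
    # Add epsilon self-loops so the paths can match epsilon arcs in the lattices.
    paths = k2.add_epsilon_self_loops(nbest.fsa)

    # row_ids(1) maps each path index (idx01) to its supervision index, which
    # tells intersect_device which lattice each path should be paired with.
    b_to_a_map = nbest.shape.row_ids(1)
    path_lats = k2.intersect_device(a_fsas=lats, b_fsas=paths,
                                    b_to_a_map=b_to_a_map, sorted_match_a=True)
    path_lats = k2.top_sort(k2.connect(path_lats))

    one_best = k2.shortest_path(path_lats, use_double_scores=True)
    one_best = k2.remove_epsilon(one_best)
    return Nbest(fsa=one_best, shape=nbest.shape)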

So at this point we have an Nbest with a ragged number of paths per supervision, up to 1000 (depending on how many
unique paths we got), where each path is just a linear sequence of arcs, one per token, and it has costs defined per
token. (It may also have other types of label and cost that were passively inherited.) The way we allocate
these costs, e.g. of epsilons and token-repeats, to each token will of course be a little arbitrary -- it's a function
of how the epsilon-removal algorithm works -- and we can try to figure out later on whether it needs to be changed
somehow.

We get the total_scores of this Nbest object; they will be used in determining which paths to use in the first
n-best list that we rescore. We can define its total_scores() function so that it returns them as a ragged array,
which is what they logically are.
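
Its total_scores() could then be something like the following sketch (constructing k2.RaggedFloat from a shape plus a values tensor is an assumption about the API of this period):

import torch


def total_scores(nbest: Nbest) -> k2.RaggedFloat:
    """Return the total score of each path as a ragged array with the same
    [utt][path] shape as the Nbest, since that is what it logically is."""
    per_path = nbest.fsa.get_tot_scores(use_double_scores=True,
                                        log_semiring=False)
    # Attach the flat per-path scores to the [utt][path] shape.
    return k2.RaggedFloat(nbest.shape, per_path.to(torch.float32))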

@danpovey

OK, the next step is to determine the subset of paths in the Nbest object to rescore. The input to this process is the ragged array of total_scores that we obtained, as mentioned above, from composing with the lattices, and the immediate output would be a RaggedInt/Ragged<int32_t> containing the subset of idx01's into the Nbest object that we want to retain. [This will be regular, i.e. we keep the same number from each supervision, even if this means having to use repeats. We'll have to figure out later what to do in case no paths survived in one of the supervisions.] We can use the shape of this to create the new Nbest object, indexing the Fsa of the original Nbest object with the idx01's to get the correct subset.

For the very first iteration of our code we can just have this take the most likely n paths, although this is likely not optimal (it might not have enough diversity); we can figure this out later. So at this point we still have an Nbest object, but it has a regular structure, so it will be easier to do rescoring with. Note: it is important that we keep the original acoustic and LM scores per token (as the 'scores' in the Fsa), because we will later have a prediction scheme that makes use of these.
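
A deliberately simple first version of this selection could be pure PyTorch; since the result is regular, the idx01's are returned here as a 2-D tensor for simplicity, and the case of a supervision with no surviving paths is left open, as noted above:

import torch


def pick_paths_to_rescore(tot_scores, num_paths: int = 10) -> torch.Tensor:
    """tot_scores: a list with one 1-D tensor of path total_scores per
    supervision (i.e. the ragged total_scores split by supervision).
    Returns, for each supervision, `num_paths` idx01 values into the Nbest,
    repeating the best path where a supervision has fewer paths so that the
    result is regular."""
    picked = []
    offset = 0  # idx01 of the first path of the current supervision
    for scores in tot_scores:
        k = min(num_paths, scores.numel())
        best = torch.topk(scores, k).indices + offset
        if k < num_paths:
            # Pad with repeats of the best path to keep a regular structure.
            best = torch.cat([best, best[:1].repeat(num_paths - k)])
        picked.append(best)
        offset += scores.numel()
    return torch.stack(picked)  # shape: (num_supervisions, num_paths)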

@danpovey

danpovey commented Jul 11, 2021

Any rescoring processes we have (e.g. LM rescoring, transformer decoding) should produce an Nbest object with the exact same structure as the one produced in the comment above, i.e. with a regular number of paths per supervision, like 10.

We'll need this exact same structure to be preserved so that our process for finding the n-best paths to rescore will work.
It is a selection process where, for each path that we did not select in the 1st round of rescoring, we model its expected total-score-after-rescoring as a Gaussian distribution, and we rank the paths by the probability of being better than the best path from the 1st round. Computing this probability requires the Gaussian integral, but since we just need ranks we can rank them by sigma value: i.e. the position, in standard deviations of the distribution, of the best score from the 1st round of rescoring.
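
In code, this ranking reduces to sorting by a z-score; a minimal sketch for one supervision (plain PyTorch, names hypothetical):

import torch


def rank_by_sigma(pred_mean: torch.Tensor,
                  pred_std: torch.Tensor,
                  best_rescored: float) -> torch.Tensor:
    """Rank the not-yet-rescored paths of one supervision by how likely they
    are to beat the best path from the 1st rescoring round.

    pred_mean, pred_std: predicted Gaussian parameters of each candidate
    path's total score after rescoring.
    best_rescored: the best total score obtained in the 1st round.

    P(path beats best) = 1 - Phi((best - mean) / std), which is monotonic in
    the sigma value z = (best - mean) / std, so we can sort by z directly
    without evaluating the Gaussian integral.
    """
    z = (best_rescored - pred_mean) / pred_std.clamp_min(1e-6)
    return torch.argsort(z)  # smallest z first == most promising first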

This will require us to train a simple model to predict the total-score of a path. For each word-position in each of the remaining paths (i.e. those that were not selected in the 1st pass), we want to predict the score for that position after rescoring as a Gaussian. Let an "initial-score" be an element of the .scores of the n-best lists before neural rescoring, and a "final-score" be an element of the .scores of the n-best lists after neural rescoring.
For each position that we want to predict, the inputs are:
- The initial-score
- The mean and variance of the set of final-scores from the most closely matching positions that we rescored in the 1st round, together with the longest-matching n-gram order (i.e. was n equal to 0, 1, 2, 3..., taking a complete match including BOS to be infinity or some large maximum).
We train some simple model -- possibly a neural model with very few neurons, or several linear regressions with multiple categories based on n-gram order -- that predicts the mean and variance of a Gaussian to model the final-score. We will have to create some interface for this model to encapsulate it nicely.
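
One possible shape for such a model, as a sketch only (the feature layout and layer sizes here are guesses, not a decided design):

import torch
import torch.nn as nn


class ScorePredictor(nn.Module):
    """A deliberately tiny model (one of the options mentioned above) that
    predicts the mean and variance of a token's final-score from its
    initial-score and the best-matching-position statistics."""

    def __init__(self, max_order: int = 4, hidden_dim: int = 16):
        super().__init__()
        self.order_embed = nn.Embedding(max_order + 1, 4)
        self.net = nn.Sequential(
            nn.Linear(4 + 4, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2))  # -> (mean, log_var)

    def forward(self, feats: torch.Tensor, ngram_order: torch.Tensor):
        # feats: (N, 4) = (initial_score, match_mean, match_var, count)
        # ngram_order: (N,) ints in [0, max_order]
        x = torch.cat([feats, self.order_embed(ngram_order)], dim=-1)
        out = self.net(x)
        mean, log_var = out[:, 0], out[:, 1]
        return mean, log_var.exp()

Training would minimize the Gaussian negative log-likelihood of the observed final-scores, e.g. with torch.nn.GaussianNLLLoss; the several-linear-regressions option would simply fit one regression per n-gram-order category instead.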

The inputs to this model include the mean and variance of the best-matching positions, and an n-gram order. What I mean here is: for a particular position in a path, we find the longest matching sequence (i.e. up to and including this word) in any of the n-best lists that we actually rescored; if there are multiple matches with the same longest length, we treat them as a set (if there is just one, the variance would be 0). We can also provide this count to the model. The mean and variance refer to the mean and variance of the scores at those longest-matching positions.

Now, it might look like this process of finding the set of longest-matching words, and computing the mean and variance of their scores, would be very time-consuming. Actually it can be done very efficiently -- in time linear in the total number of words we are processing, counting both words in paths that we selected in the 1st pass (keys) and words in paths that we did not (queries) -- although the algorithms will need to be done on CPU for now because they are too complex to implement on GPU in a short timeframe. I'll describe these algorithms in the next comment on this issue.

@danpovey

Let me first describe an internal interface for the code that gets the (mean, variance, ngram_order) of the best-matching positions that we rescored in the 1st round. I'm choosing a level of interface that will let you see the basic picture, but there will be other interfaces above and below it. Something like this, assuming it's in Python:

def get_best_matching_stats(tokens: k2.RaggedInt, scores: Tensor, counts: Tensor,
                            eos: int, min_token: int, max_token: int,
                            max_order: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    For "query" sentences, this function gets the mean and variance of scores from the
    best-matching words-in-context in a set of provided "key" sentences.  This matching
    process matches the word and the words preceding it, looking for the highest-order
    match it can find.  It is an efficient implementation using suffix arrays (done on
    CPU for now, since the implementation is not entirely trivial).  The intended
    application is in estimating the scores of hypothesized transcripts when we have
    actually computed the scores for only a subset of the hypotheses.

    Args:
      tokens: a k2 ragged tensor with 3 axes, where the 1st axis is over separate
          utterances, the 2nd axis is over different elements of an n-best list, and
          the 3rd axis is over words.  Some sub-sub-lists represent queries, some
          represent keys to be queried; see `scores` and `counts`.  This would likely
          be something like the following:
            [ [ [ the, cat, said, eos ], [ the, cat, fed, eos ] ],
              [ [ hi, my, name, is, eos ], [ bye, my, name, is, eos ] ], ... ]
          where the words would actually be the corresponding integers, and they might
          not even correspond to words (they might be BPE pieces).
      scores: a torch.Tensor with shape equal to (tokens.num_elements(),), with
          dtype == torch.float.  These are the values that we want to get the mean and
          variance of (they would likely be the scores of words or tokens after a
          first round of n-best-list rescoring).  They only represent scores in "key
          positions" (meaning: positions where the corresponding counts value is
          nonzero); in "query positions", where the counts value is 0, the score is
          required to be zero.
      counts: a torch.Tensor with shape equal to (tokens.num_elements(),), with
          dtype == torch.int32, where the values should be 1 for positions where the
          'scores' are nontrivially defined (representing keys) and 0 for positions
          where the 'scores' are zero (representing queries).  In our example, let's
          take counts to be:
            [ [ [ 1, 1, 1, 1 ], [ 0, 0, 0, 0 ] ], [ [ 1, 1, 1, 1, 1 ], [ 0, 0, 0, 0, 0 ] ], ... ]
      eos: the value of the eos (end of sentence) symbol; this is used as an extra
          padding value before the first path of each utterance, to ensure consistent
          behavior when matching past the beginning of the sentence.
      min_token: the lowest value of token that might be included in `tokens`,
          including BOS and EOS symbols; it may be negative, like -1.
      max_token: might equal the vocabulary size, or simply the maximum token value
          included in this particular example.
      max_order: the maximum n-gram order to ever match; it will be used as a limit on
          the `ngram_order` returned (but not on the actual length of match), and it
          will also be used when a match extends all the way to the beginning of a
          sentence including the implicit beginning-of-sentence symbol.  (Note: if
          there is also an explicit bos symbol at the beginning of each sentence, it
          doesn't matter.)

    Returns a tuple (mean, var, count, ngram_order), where:
      mean: a torch.Tensor with shape equal to (tokens.num_elements(),), with
          dtype == torch.float, representing the mean of the scores over the set of
          longest-matching key positions; this is defined for both key positions and
          query positions, although the caller may not be interested in the value at
          key positions.
      var: a torch.Tensor with shape equal to (tokens.num_elements(),), with
          dtype == torch.float, representing the variance of the scores over the set
          of longest-matching key positions.  This is expected to be zero at positions
          where count equals 1.
      count: a torch.Tensor with shape equal to (tokens.num_elements(),), with
          dtype == torch.int32, representing the number of longest-matching key
          positions.  This will be 1 if there was only a single position of the
          longest-matching order, and otherwise a larger number.  (Note: if no words
          at all matched, ngram_order would be zero and the mean and variance would
          encompass all positions in all paths for the current utterance.)
      ngram_order: a torch.Tensor with shape equal to (tokens.num_elements(),), with
          dtype == torch.int32, representing the n-gram order of the best-matching
          word-in-context to the current word, up to max_order; or max_order in the
          case where we match all the way to the beginning of a sentence.  Example: in
          the case of 'name' in the 2nd sentence of the 2nd utterance, the ngram_order
          would be 2, corresponding to the longest-matching sequence "my name".  In
          the case of 'fed' in the 2nd sentence of the 1st utterance, the ngram_order
          would be 0.  In the case of 'cat' in the 2nd sentence of the 1st utterance,
          the ngram_order would equal max_order because we match all the way to the
          beginning of the sentence.
    """
    pass
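
To make the intended usage concrete, here is a hypothetical call following the interface above, using the example from the docstring with made-up integer ids (eos=1, the=2, cat=3, said=4, fed=5, hi=6, my=7, name=8, is=9, bye=10) and made-up scores; the string constructor for k2.RaggedInt is also an assumption:

import torch
import k2

tokens = k2.RaggedInt(
    '[ [ [ 2 3 4 1 ] [ 2 3 5 1 ] ] [ [ 6 7 8 9 1 ] [ 10 7 8 9 1 ] ] ]')
# Key positions carry real scores; query positions must have score 0.
scores = torch.tensor([0.1, 0.2, 0.3, 0.4,  0.0, 0.0, 0.0, 0.0,
                       0.5, 0.6, 0.7, 0.8, 0.9,  0.0, 0.0, 0.0, 0.0, 0.0],
                      dtype=torch.float32)
counts = torch.tensor([1, 1, 1, 1,  0, 0, 0, 0,
                       1, 1, 1, 1, 1,  0, 0, 0, 0, 0], dtype=torch.int32)

mean, var, count, ngram_order = get_best_matching_stats(
    tokens, scores, counts, eos=1, min_token=1, max_token=10, max_order=2)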

The implementation of this function will use suffix arrays. For now everything will be done on the CPU. The basic plan is as follows; let's say we do it separately for each utterance.
We reverse the order of the words (and possibly utterances; utterance order doesn't matter though), append an extra eos symbol, add 1 - min_token to everything (so the smallest possible token maps to 1, avoiding zero and negative values), and append a zero, so that, for the 1st utterance above, we'd have:
[ eos+n, fed+n, cat+n, the+n, eos+n, said+n, cat+n, the+n, eos+n, 0 ],
where n equals 1 - min_token.
Next we compute the suffix array: an array of int32, of the same size as the list above, giving the lexicographic order of the suffixes of the sequence starting at each position. This can be done reasonably simply in O(n) time; see, for example, the C++ code in:
https://algo2.iti.kit.edu/documents/jacm05-revised.pdf
Next we need to compute the LCP array from the suffix array (the array of lengths of longest common prefixes between lexicographically successive suffixes); see:
https://www.geeksforgeeks.org/%C2%AD%C2%ADkasais-algorithm-for-construction-of-lcp-array-from-suffix-array/
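
For checking the eventual implementation, a naive pure-Python reference for both arrays could look like this: direct sorting for the suffix array (O(n^2 log n), fine for tests; the real code would use the linear-time algorithm from the paper above), and Kasai's O(n) algorithm for the LCP array.

def suffix_array_naive(seq):
    """Indexes of the suffixes of `seq` in lexicographic order."""
    return sorted(range(len(seq)), key=lambda i: seq[i:])


def lcp_array_kasai(seq, sa):
    """lcp[r] = length of the longest common prefix of the suffixes starting
    at sa[r - 1] and sa[r] (with lcp[0] = 0), computed in O(n)."""
    n = len(seq)
    rank = [0] * n
    for r, i in enumerate(sa):
        rank[i] = r
    lcp = [0] * n
    h = 0
    for i in range(n):  # walk suffixes in text order so h decreases by <= 1
        if rank[i] > 0:
            j = sa[rank[i] - 1]
            while i + h < n and j + h < n and seq[i + h] == seq[j + h]:
                h += 1
            lcp[rank[i]] = h
            if h > 0:
                h -= 1
        else:
            h = 0
    return lcp
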
The suffix array is an efficient data structure that can be used to simulate algorithms on "suffix tries"; a suffix trie is a compressed tree of suffixes of the string.
Viewed as an algorithm on the suffix trie, we can compute the things we need as follows:

  • Compute the inclusive sums of the scores, squared scores, and counts, in the order given by the suffix array, i.e. index those quantities by the suffix array and then compute the inclusive (prefix) sum.
  • The means and variances we need can be computed from partial sums of the scores, squared scores and counts over sub-trees of the suffix trie; these correspond to differences between elements of those inclusive sums at the suffix-array positions that bound the corresponding subtrees (a small sketch of this appears after this list). The n-gram orders involved correspond to the orders of the lcp-intervals; see http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf. We always need this information from the tightest enclosing lcp-interval that has nonzero count.
  • In http://www.mi.fu-berlin.de/wiki/pub/ABI/RnaSeqP4/enhanced-suffix-array.pdf, see "Algorithm 1". This enumerates all of the non-leaf nodes of the suffix trie, or equivalently, all the lcp-intervals in the suffix array; it involves a stack. I propose to go through that traversal algorithm twice. The first time, we allocate a number to each lcp-interval, based on the time it was first pushed onto the stack, and we record, in an array indexed by this number, its (lcp-value, left-boundary, right-boundary) and the number of its parent lcp-interval, and also the number of its closest enclosing lcp-interval that has nonzero count. The parent information can be obtained as in Algorithm 2, see lastInterval. We then go through this list linearly and, for each interval, record the number of its closest enclosing parent interval that has nonzero count. (This requires recursion, but is still linear-time when amortized if we save the values computed in the recursion; the reason this is not completely trivial is that parents can both precede and follow children in this numbering.) The second time we go through Algorithm 1, we record the stats for each 'i' value, i.e. each suffix in the suffix array, and we produce the output. This requires us to know the immediate parent lcp-interval for each 'i' value; I still have to figure this out, but it doesn't look like it should be too hard.
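
To make the first two bullets concrete, here is a minimal sketch for a single utterance, assuming we already have the suffix array, plus scores and counts arranged in the same (reversed, padded) order as the sequence the suffix array was built from, with zeros at padding and query positions; the lcp-interval enumeration itself (Algorithm 1) is not shown:

import torch


def interval_stats(scores: torch.Tensor, counts: torch.Tensor, sa):
    """Precompute inclusive sums in suffix-array order so that the stats of
    any lcp-interval [l, r] (a subtree of the suffix trie) can be read off as
    a difference of two prefix-sum entries."""
    sa = torch.tensor(sa, dtype=torch.long)
    s = scores[sa]                      # scores in suffix-array order
    c = counts[sa].to(torch.float32)    # counts in suffix-array order
    sum_s = torch.cumsum(s, dim=0)
    sum_s2 = torch.cumsum(s * s, dim=0)
    sum_c = torch.cumsum(c, dim=0)

    def stats(l: int, r: int):
        """Mean/variance/count of key scores whose suffix-array index is in
        [l, r]; assumes the interval contains at least one key position."""
        def rng(cum):
            return cum[r] - (cum[l - 1] if l > 0 else 0.0)
        n = rng(sum_c)
        mean = rng(sum_s) / n
        var = rng(sum_s2) / n - mean * mean
        return mean, var, int(n)

    return stats

The lcp-value of the tightest enclosing interval with nonzero count then gives the n-gram order (capped at max_order), and stats(l, r) over that interval gives the mean, variance and count to return for the positions inside it.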

@csukuangfj

I will first implement the ideas in the first comment, i.e., the Nbest class, and try it in #198.
I will go on to the next comments after that is done.

@danpovey

Incidentally, regarding padding, speechbrain has something called undo_padding
speechbrain/speechbrain#751 (comment)
which might possibly be useful. This is just something I noticed; if you disagree please ignore it.

@pkufool

pkufool commented Jul 19, 2021

For the very first iteration of our code we can just have this take the most likely n paths, although this is likely not optimal (might not have enough diversity).

So these n paths (after rescoring) will be the keys used to calculate the mean and variance, and the other paths that were not selected will be the queries. Is that right?

@danpovey

Yes.
