Trace frame-level scores using lattice #248
|
Please refer to the help doc of
There are two extra optional arguments:
def intersect_dense(a_fsas: Fsa,
                    b_fsas: DenseFsaVec,
                    output_beam: float,
                    a_to_b_map: Optional[torch.Tensor] = None,
                    seqframe_idx_name: Optional[str] = None,
                    frame_idx_name: Optional[str] = None) -> Fsa:
You can use
lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0, seqframe_idx_name='seqframe', frame_idx_name='frame')
After the above call, the resulting
In your case, the supervision contains only one utterance, so You can invoke [EDITED]: The whole process is also differentiable. |
Thanks for the reply. It seems that
Previously, I did not notice the API
After that, I tried to compute the frame-level scores by
However, the results of the two versions do not match.
Currently, I don't know which version is correct. |
arcs = lats.arcs.values()[:, :2]
# arcs is a 2-D torch.int32 tensor
for idx, (src, dst) in enumerate(arcs.tolist()):
    # note: src is not used, so you can replace it with an underscore _
    frame_idx = lats.frame[idx]
    # now you know the state `dst` belongs to the frame `frame_idx`
    # You can add the forward_score of this state to a list corresponding to the frame `frame_idx`
    #
    # Caution: You have to avoid adding the `dst` state multiple times
# At this point, you know the states corresponding to each frame; you can use `log-sum-exp` to combine them.
|
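The grouping described above can be sketched without k2. This is a minimal illustration, assuming the destination state and frame index of every arc have already been copied into plain Python lists, and that one forward score per state is available. The names `frame_scores`, `arc_dst_states`, `arc_frame_idx` and `forward_scores` are hypothetical and the toy numbers are made up.

```python
import math
from collections import defaultdict


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == float("-inf"):
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def frame_scores(arc_dst_states, arc_frame_idx, forward_scores):
    """Combine per-state forward scores into one score per frame.

    arc_dst_states[i] and arc_frame_idx[i] describe arc i (hypothetical
    plain-list copies of the lattice's arc information).
    forward_scores[s] is the forward score of state s.
    """
    states_per_frame = defaultdict(set)
    for dst, frame in zip(arc_dst_states, arc_frame_idx):
        # A set avoids adding the same `dst` state more than once
        # when several arcs enter it.
        states_per_frame[frame].add(dst)
    return {
        frame: logsumexp([forward_scores[s] for s in sorted(states)])
        for frame, states in sorted(states_per_frame.items())
    }


# Toy lattice: state 2 is entered by two arcs on frame 0 and
# state 3 by two arcs on frame 1.
arc_dst_states = [1, 2, 2, 3, 3]
arc_frame_idx = [0, 0, 0, 1, 1]
forward_scores = [0.0, -1.0, -2.0, -0.5]

scores = frame_scores(arc_dst_states, arc_frame_idx, forward_scores)
for frame, score in scores.items():
    print(frame, score)
```

Using a set per frame is one simple way to respect the caution about duplicate states. How the final state and the last frame should be treated is discussed further down the thread and is not covered by this sketch.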
As I posted above, you can iterate over the arcs; for each arc, you can get its frame_idx and dest_state. If you have multiple utterances, then you have to use |
Note: I have re-edited the demo code. |
Please use a small
A frame can correspond to multiple states, while a state belongs to only one frame. |
You can note down which state belongs to which frame.
You can get the total scores for frame 0 using
Note: For the last frame |
@csukuangfj
|
I'm also worried about the scores on the final states: |
For arc |
I have dumped the lattice as below.
In my
Following this advice, I have also revised my
and the output of this is like below. It only ignores the start state Use the new
Note: |
As shown in the lattice above, arc Thanks :) |
I just created a colab notebook (see https://colab.research.google.com/drive/1iyc_q8aHuKd-RZxtYv9EqfyjB2QZDSOx?usp=sharing). The following code you posted:
is not equivalent to the one where the |
My observation is the same: the |
If I understand correctly:
If this is true, I should compute the scores on the first T-1 frames by the
Also, I should never call
I'm still curious about: |
Maybe @danpovey has more to say about it.
I was explaining why the |
Thanks! As @csukuangfj advised, I have rewritten the two methods. The results now match.
The results:
|
Hi, team!
I have encountered a problem with k2 in my code. Below is the description of this problem.
For a `nnet_output` with shape `[B, T, D]`, I am trying to calculate the scores on a graph (MMI numerator or denominator) with any prefix segment of `nnet_output`, namely `nnet_output[:, :t, :]`, where `t` is any index smaller than T (the total length in time axis).
Currently, I implement it by a loop. But this leads to much computation. My code is below
Could these scores be calculated by parsing the `lats` obtained from the whole `nnet_output`, which means we can calculate them with only one `k2.intersect_dense`? Approximation is also ok for me.
Thanks for your help! :)