-
Notifications
You must be signed in to change notification settings - Fork 217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Build ctc graph from symbols in batch mode #776
Conversation
Maybe a naive question -- what is the difference between this implementation, and doing it via Python-level k2 operations? I remember there was some code from @csukuangfj for MMI that batched all the FSA ops for a given list of utterances. Just curious what is the gain. |
@pzelasko (1) There are no optional silences before and after each word in this pull-request. |
Create an FsaVec containing ctc graph FSAs, given a list of sequences of | ||
symbols | ||
|
||
@param [in] symbols Input symbol sequences (must not contain |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add a check inside the kernel that none of the input symbols is -1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I add a checking in the kernel set_num_arcs
, which will enumerate all the symbols. I think it's enough.
k2/csrc/fsa_algo.cu
Outdated
sym_state_idx01 = state_idx01 / 2 - fsa_idx0, | ||
remainder = state_idx01 % 2, | ||
current_num_arcs = 2; // normally there are two arcs, self-loop | ||
// and arc points to the next state |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// and arc points to the next state | |
// and arc pointing to the next state |
k2/csrc/fsa_algo.cu
Outdated
} else { | ||
int32_t current_symbol = symbol_data[sym_state_idx01], | ||
// we set the next symbol of the last symbol to -1, so | ||
// the following if clause will always be true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain the comment:
so the following if clause will always be true
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It means that the last symbol state would always have 3 arcs. we handle next_symbol
here, first, to avoid segment fault error, second, to confirm that current_symbol != next_symbol
so we will assign the last symbol state 3 arcs.
will explain more in the docs.
}); | ||
|
||
ExclusiveSum(num_states_for, &num_states_for); | ||
Array1<int32_t> &fsa_to_states_row_splits = num_states_for; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason to introduce another name for num_states_for
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After doing ExclusiveSum, num_states_for
is actually the row_splits
, using different names here just for easy understanding, we'll use fsa_to_states_row_splits
to construct ragged_shape below.
k2/python/csrc/torch/fsa_algo.cu
Outdated
if (need_arc_map) tensor = ToTorch(arc_map); | ||
return std::make_pair(graph, tensor); | ||
}, | ||
py::arg("labels"), py::arg("need_arc_map") = true, py::arg("gpu_id")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
py::arg("labels"), py::arg("need_arc_map") = true, py::arg("gpu_id")); | |
py::arg("symbols"), py::arg("need_arc_map") = true, py::arg("gpu_id")); |
k2/python/csrc/torch/fsa_algo.cu
Outdated
m.def( | ||
"ctc_graph", | ||
[](const Ragged<int32_t> &symbols, bool need_arc_map = true, | ||
int32_t /*unused_gpu_id*/) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The last argument has no default value but it is after an argument with default value.
It's not valid C++, I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw the error in actions, it could compile successfully on linux, will fix it.
I think there should be a boolean option to specify the type of CTC topology: "standard" or "simplified", where the "standard" one makes the blank mandatory between a pair of identical symbols. |
Ok, will add the option. |
Great, thanks!
Feel free to merge when you guys think it's OK.
…On Thu, Jul 8, 2021 at 5:34 PM pkufool ***@***.***> wrote:
Add the option standard, default True, the standard one makes the blank
mandatory between a pair of identical symbols.
An example to demonstrate their difference is as follow:
[image: image]
<https://user-images.githubusercontent.com/11765074/124899049-50872780-e012-11eb-9d85-96ea49344410.png>
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#776 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAZFLO764YMISFYRLYEGCFTTWVWKJANCNFSM476T765Q>
.
|
// There is no arcs for final states | ||
if (sym_state_idx01 == sym_final_state) { | ||
current_num_arcs = 0; | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
} else { | |
} else if(!standard) { | |
current_num_arcs = 3; | |
} else { | |
// same as before the latest change | |
} |
For non-standard topo, current_num_arcs
is always 3. Put it into
a separate if statement can save some work.
Co-authored-by: Fangjun Kuang <[email protected]>
See k2-fsa/snowfall#220
The python api and results are as follows: