[WIP] Export CTC decoding algorithm to sherpa #1093
Conversation
k2/torch/csrc/CMakeLists.txt
Outdated
@@ -65,3 +87,21 @@ if(K2_ENABLE_TESTS)
    k2_add_torch_test(${source})
  endforeach()
endif()

file(MAKE_DIRECTORY
k2/torch/csrc/CMakeLists.txt
Outdated
  ${PROJECT_BINARY_DIR}/include/k2
)

install(TARGETS k2_torch_api
Suggested change:
-install(TARGETS k2_torch_api
+install(TARGETS k2_torch_api k2_torch
k2/torch/bin/CMakeLists.txt
Outdated
#----------------------------------------
# CTC decoding
#----------------------------------------
-set(ctc_decode_srcs ctc_decode.cu)
+set(ctc_decode_srcs ctc_decode.cu ${feature_srcs})
It will recompile feature_srcs for each binary. Shall we make it a library that can be shared?
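The suggestion above could be sketched in CMake roughly as follows. This is a hypothetical illustration only — the target name `k2_features` and the exact linkage are assumptions, not code from this PR, and the thread ultimately kept the original per-binary setup:

```cmake
# Hypothetical sketch: compile the feature-extraction sources once into a
# helper library instead of adding ${feature_srcs} to every binary.
add_library(k2_features STATIC ${feature_srcs})
target_link_libraries(k2_features PUBLIC ${TORCH_LIBRARIES})

# Each decoding binary then links the shared helper library.
set(ctc_decode_srcs ctc_decode.cu)
add_executable(ctc_decode ${ctc_decode_srcs})
target_link_libraries(ctc_decode PRIVATE k2_features)
```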
OK, I think so. Let's change it back to the original.
I suggest that we create a new PR to bind the exposed APIs to Python and provide Python APIs and examples to decode models trained with CTC loss from various frameworks, such as icefall, nemo, espnet, speechbrain, wenet, etc.
I think we MUST create a new PR, because the binding and demo code will live in sherpa, not k2.
k2 only requires log_softmax_out + TLG for decoding, so it is easier to use.
OK, I got your idea. I think the functions to be wrapped are in torch_api.h.
int32_t min_activate_states,
int32_t max_activate_states,
int32_t subsampling_factor) {
FsaClassPtr GetLattice(torch::Tensor log_softmax_out,
I suggest that we provide only a single method Decode() to give the results directly. In the current approach, users have to call BestPath on the returned lattice, and that is the only function users can call on the returned lattice. That is an implementation detail, and we can hide it from the users.
> I suggest that we provide only a single method Decode() to give the results directly. In the current approach, users have to call BestPath on the returned lattice, and that is the only function users can call on the returned lattice. That is an implementation detail, and we can hide it from the users.
I thought we could use this lattice to do rescoring: we can implement one_best_decoding and ngram_rescoring in sherpa.
I think there is no need to hide the lattice. See k2-fsa/sherpa#177 for our plans.
We can remove kaldifeat_core from k2_torch and make the binaries depend on kaldifeat_core directly. (See k2_torch in k2/cmake/k2Config.cmake.in, line 48 in c3a7404.)

In torch_api.h, we can have a function to load HLG.pt. The returned type can be std::shared_ptr<k2::FsaClass>. We can define an alias FsaClassPtr for it, like RaggedShapePtr. k2::CtcTopo() should also return a value of type FsaClassPtr.

Decode() takes:
- log_softmax_out: a 3-D tensor of shape (N, T, C)
- log_softmax_out_lens: a 1-D tensor of shape (N,)
- FsaClassPtr: can be either a CtcTopo or an HLG

and it returns std::vector<std::vector<int32_t>>.
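A minimal sketch of that Decode() contract, with two loudly labeled stand-ins: nested std::vector replaces torch::Tensor, and greedy best-path CTC decoding replaces the FST-based search over a CtcTopo/HLG graph. This is not the k2/sherpa implementation, only an illustration of the input/output shapes discussed above:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of Decode(): takes per-frame log-probabilities of
// shape (N, T, C) plus per-utterance lengths (N,), returns one token
// sequence per utterance. Greedy CTC decoding stands in for graph search.
std::vector<std::vector<int32_t>> Decode(
    const std::vector<std::vector<std::vector<float>>> &log_softmax_out,
    const std::vector<int32_t> &log_softmax_out_lens,
    int32_t blank_id = 0) {
  std::vector<std::vector<int32_t>> results;
  results.reserve(log_softmax_out.size());
  for (std::size_t n = 0; n < log_softmax_out.size(); ++n) {
    std::vector<int32_t> hyp;
    int32_t prev = -1;
    for (int32_t t = 0; t < log_softmax_out_lens[n]; ++t) {
      const auto &frame = log_softmax_out[n][t];
      // Pick the most likely class for this frame.
      int32_t best = 0;
      for (std::size_t c = 1; c < frame.size(); ++c) {
        if (frame[c] > frame[static_cast<std::size_t>(best)]) {
          best = static_cast<int32_t>(c);
        }
      }
      // Standard CTC collapse rule: drop repeated labels, then drop blanks.
      if (best != prev && best != blank_id) hyp.push_back(best);
      prev = best;
    }
    results.push_back(std::move(hyp));
  }
  return results;
}
```

With the real API, the graph argument (CtcTopo or HLG) would select the search space; here the greedy stand-in ignores it entirely.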