
[WIP] Export CTC decoding algorithm to sherpa #1093

Merged
merged 7 commits into k2-fsa:master on Nov 9, 2022

Conversation

@pkufool (Collaborator) commented Nov 6, 2022

See k2-fsa/sherpa#177 for our plans.

  • Move https://github.com/k2-fsa/k2/blob/master/k2/csrc/torch_api.h to https://github.com/k2-fsa/k2/tree/master/k2/torch/csrc
  • Remove the dependency kaldifeat_core from k2_torch. We can make the binaries depend on kaldifeat_core directly.
  • Export the library k2_torch in
    set(K2_LIBRARIES k2_torch_api k2_log k2context k2fsa)
  • Add the following functions to torch_api.h
    • A function to load HLG.pt. The returned type can be std::shared_ptr<k2::FsaClass>. We can define an alias FsaClassPtr for it, like RaggedShapePtr
    • A function to wrap k2::CtcTopo(). It should also return a value of type FsaClassPtr.
    • A function for CTC/HLG decoding. It takes the following inputs:
      - log_softmax_out: a 3-D tensor of shape (N, T, C)
      - log_softmax_out_lens: a 1-D tensor of shape (N,)
      - FsaClassPtr: can be either a CtcTopo or an HLG
      and it returns std::vector<std::vector<int32_t>> (see the sketch after this list)
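Below is a minimal header sketch of what these additions to torch_api.h could look like. It is only an illustration: FsaClassPtr is the alias proposed above, while the function names LoadFsa, GetCtcTopo, and Decode (and their exact signatures) are assumptions, not the final API.

// Hypothetical sketch of the proposed additions; names and signatures are
// illustrative and may differ from what finally lands in torch_api.h.
#include <memory>
#include <string>
#include <vector>

#include "torch/torch.h"

namespace k2 {

class FsaClass;  // defined in k2/torch/csrc
using FsaClassPtr = std::shared_ptr<FsaClass>;  // analogous to RaggedShapePtr

// Load a decoding graph such as HLG.pt from disk.
FsaClassPtr LoadFsa(const std::string &filename);

// Wrap k2::CtcTopo(): build a CTC topology covering tokens 0..max_token_id.
FsaClassPtr GetCtcTopo(int32_t max_token_id);

// CTC/HLG decoding.
//   log_softmax_out:      3-D tensor of shape (N, T, C)
//   log_softmax_out_lens: 1-D tensor of shape (N,)
//   decoding_graph:       either a CtcTopo or an HLG
// Returns one sequence of IDs per utterance.
std::vector<std::vector<int32_t>> Decode(torch::Tensor log_softmax_out,
                                         torch::Tensor log_softmax_out_lens,
                                         FsaClassPtr decoding_graph);

}  // namespace k2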

@pkufool changed the title from "Export CTC decoding algorithm to sherpa" to "[WIP] Export CTC decoding algorithm to sherpa" on Nov 6, 2022
@@ -65,3 +87,21 @@ if(K2_ENABLE_TESTS)
k2_add_torch_test(${source})
endforeach()
endif()

file(MAKE_DIRECTORY
  DESTINATION
  ${PROJECT_BINARY_DIR}/include/k2
)

Collaborator:

Suggested change (remove the DESTINATION line):

file(MAKE_DIRECTORY
  ${PROJECT_BINARY_DIR}/include/k2
)

install(TARGETS k2_torch_api

Collaborator:

Suggested change (also install the k2_torch target):

-install(TARGETS k2_torch_api
+install(TARGETS k2_torch_api k2_torch

#----------------------------------------
# CTC decoding
#----------------------------------------
-set(ctc_decode_srcs ctc_decode.cu)
+set(ctc_decode_srcs ctc_decode.cu ${feature_srcs})
Collaborator:

It will recompile feature_srcs for each binary. Shall we make it a library that can be shared?

Collaborator Author:

OK, I think so; let's change it back to the original.

@csukuangfj (Collaborator)

I suggest that we create a new PR to bind the exposed APIs to Python and provide python APIs and examples to decode models trained using CTC loss from various frameworks, such as icefall, nemo, espnet, speechbrain, wenet, etc.

@pkufool (Collaborator Author) commented Nov 8, 2022

> I suggest that we create a new PR to bind the exposed APIs to Python and provide python APIs and examples to decode models trained using CTC loss from various frameworks, such as icefall, nemo, espnet, speechbrain, wenet, etc.

I think we MUST create a new PR, because the binding and demo code will be in sherpa, not k2.

@csukuangfj (Collaborator) commented Nov 8, 2022

k2 only requires log_softmax_out + TLG for decoding, so it is easier to use.
People can just import k2 and provide the required inputs for decoding and they may not want to install sherpa.

@pkufool (Collaborator Author) commented Nov 8, 2022

> k2 only requires log_softmax_out + TLG for decoding, so it is easier to use. People can just import k2 and provide the required inputs for decoding and they may not want to install sherpa.

OK, I see your point. To do that, the functions to be wrapped are in k2/torch/*, not in torch_api.h. That is not related to this PR; I will do it in another change, thanks!

int32_t min_activate_states,
int32_t max_activate_states,
int32_t subsampling_factor) {
FsaClassPtr GetLattice(torch::Tensor log_softmax_out,
Collaborator:

I suggest that we provide only a single method Decode() to give the results directly.

In the current approach, users have to call BestPath on the returned lattice and that is the only function that users can call for the returned lattice.

That is an implementation detail and we can hide it from the users.

Collaborator Author:

> I suggest that we provide only a single method Decode() to give the results directly.
>
> In the current approach, users have to call BestPath on the returned lattice and that is the only function that users can call for the returned lattice.
>
> That is an implementation detail and we can hide it from the users.

I thought we could use this lattice to do rescoring; we can implement one_best_decoding and ngram_rescoring in sherpa.

@pkufool (Collaborator Author) Nov 8, 2022

I think there is no need to hide the lattice.
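To make the trade-off concrete, here is a rough caller-side sketch of the workflow pkufool describes. It is only an assumption-laden illustration: the exact signatures of GetLattice and BestPath are not visible in this excerpt, and NgramRescore is a purely hypothetical name for the sherpa-side rescoring pass mentioned above.

// Caller-side sketch only. The GetLattice/BestPath calls are simplified
// (their full parameter lists are not shown in this PR excerpt); NgramRescore
// is a hypothetical placeholder for ngram_rescoring implemented in sherpa.
#include <vector>

#include "k2/torch/csrc/torch_api.h"  // proposed location after this PR
#include "torch/torch.h"

// Hypothetical helper that would live in sherpa, not in k2.
k2::FsaClassPtr NgramRescore(k2::FsaClassPtr lattice);

std::vector<std::vector<int32_t>> DecodeBatch(
    torch::Tensor log_softmax_out,       // (N, T, C)
    torch::Tensor log_softmax_out_lens,  // (N,)
    k2::FsaClassPtr decoding_graph,      // CtcTopo or HLG
    bool use_ngram_rescoring) {
  // Returning the lattice (rather than hiding it behind a single Decode())
  // lets the caller choose the post-processing step.
  k2::FsaClassPtr lattice =
      k2::GetLattice(log_softmax_out, log_softmax_out_lens, decoding_graph);
  if (use_ngram_rescoring) {
    lattice = NgramRescore(lattice);
  }
  // One-best decoding over the (possibly rescored) lattice.
  return k2::BestPath(lattice);
}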

@pkufool merged commit e552812 into k2-fsa:master Nov 9, 2022