Skip to content

Commit

Permalink
add the decoder_prompt_id for whisper tokenizer (microsoft#775)
Browse files Browse the repository at this point in the history
* add the decoder_prompt_id for whisper tokenizer

* temporarily disable android prebuilt

* disable the prebuilt for android

* disable the prebuilt for android 2

* Add a unit test

* correct test ids
  • Loading branch information
wenbingl authored Jul 29, 2024
1 parent 620050f commit c3145b8
Show file tree
Hide file tree
Showing 10 changed files with 128,017 additions and 41 deletions.
6 changes: 3 additions & 3 deletions build.android
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ ANDROID_API_LEVEL=24
# build openssl and curl for azure ops
export ANDROID_API_LEVEL=${ANDROID_API_LEVEL}
export ANDROID_NDK_ROOT=${NDK_ROOT}
pushd "${SCRIPT_DIR}/prebuild"
./build_curl_for_android.sh ${abi_name}
popd
# pushd "${SCRIPT_DIR}/prebuild"
# ./build_curl_for_android.sh ${abi_name}
# popd

mkdir -p "${target_dir}"
pushd "${target_dir}"
Expand Down
16 changes: 16 additions & 0 deletions include/ortx_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
extError_t ORTX_API_CALL OrtxTokenize(
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);

/**
 * @brief Retrieves the Whisper decoder prompt IDs from the tokenizer.
 *
 * Builds the special-token prompt sequence used to prime Whisper decoding
 * (language token, task token, optional no-timestamps token) and replicates
 * the same sequence for every row of the batch.
 *
 * @param tokenizer A pointer to the OrtxTokenizer object.
 * @param batch_size Number of rows in the output; every row receives the same prompt IDs.
 * @param lang The language for the Whisper model decoding, like 'en'. Can be NULL, in which case no language id
 *             appears in the output.
 * @param task The task for the model, either 'translate' or 'transcribe'. Can be NULL, in which case no task id
 *             appears in the output. (NOTE(review): the implementation matches the string "translate", not
 *             "translation" — the previous wording here was misleading.)
 * @param no_timestamps Non-zero appends the no-timestamps token (i.e. suppresses timestamps in decoding);
 *                      0 omits it.
 * @param output A pointer to the OrtxTokenId2DArray object to store the output.
 * @return An extError_t value indicating the success or failure of the operation.
 */
extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(
    const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang, const char* task, int no_timestamps, OrtxTokenId2DArray** output);

/** \brief Detokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object
Expand Down
27 changes: 26 additions & 1 deletion shared/api/c_api_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,32 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char
return extError_t();
}

// C-API wrapper: validates the handle, delegates to TokenizerImpl, and packages the
// resulting id rows into a caller-owned OrtxTokenId2DArray.
extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang,
                                                 const char* task, int no_timestamps, OrtxTokenId2DArray** output) {
  // Reject null handles up front; the message is retrievable via the last-error mechanism.
  if (tokenizer == nullptr || output == nullptr) {
    ReturnableStatus::last_error_message_ = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

  // Confirm the opaque handle really wraps a tokenizer object before using it.
  const auto* tok_impl = static_cast<const TokenizerImpl*>(tokenizer);
  ReturnableStatus status = tok_impl->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
  if (!status.IsOk()) {
    return status.Code();
  }

  // Compute one prompt-id row per batch entry.
  std::vector<std::vector<extTokenId_t>> prompt_ids;
  status = tok_impl->GetDecoderPromptIds(batch_size, lang, task, no_timestamps, prompt_ids);
  if (!status.IsOk()) {
    return status.Code();
  }

  // Move the rows into a heap-allocated 2-D array whose ownership passes to the caller.
  auto* id_array = std::make_unique<ort_extensions::TokenId2DArray>().release();
  id_array->SetTokenIds(std::move(prompt_ids));
  *output = static_cast<OrtxTokenId2DArray*>(id_array);

  return extError_t();
}

extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input,
OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
Expand Down Expand Up @@ -110,7 +136,6 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const Or
*output = static_cast<OrtxStringArray*>(result);

return extError_t();
;
}

extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len,
Expand Down
82 changes: 77 additions & 5 deletions shared/api/tokenizer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

using namespace ort_extensions;

TokenizerImpl::TokenizerImpl() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer){};
TokenizerImpl::~TokenizerImpl(){};
TokenizerImpl::TokenizerImpl() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
TokenizerImpl::~TokenizerImpl() {};

OrtxStatus TokenizerImpl::Load(const std::string& dir) {
tok_config_ = std::make_shared<ort_extensions::bpe::TokenJsonConfig>();
Expand All @@ -29,9 +29,8 @@ OrtxStatus TokenizerImpl::Load(const std::string& dir) {
return status;
}

OrtxStatus TokenizerImpl::BatchEncode(
const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const {
OrtxStatus TokenizerImpl::BatchEncode(const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const {
for (const auto& s : input) {
ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(s)});
Expand Down Expand Up @@ -69,3 +68,76 @@ OrtxStatus TokenizerImpl::BatchDecode(const std::vector<span<extTokenId_t const>
// Map a single token id back to its text piece, delegating to the BPE detokenizer.
// `state` carries decoder state between successive calls — presumably for incremental
// (streaming) decoding of byte-level BPE output; confirm against BPEDecoderState usage.
OrtxStatus TokenizerImpl::Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const {
  return detokenizer_->Id2Token(id, token, state);
}

// Language code -> English name for the languages Whisper supports.
// The keys are used to build language prompt tokens of the form "<|{code}|>"
// in GetDecoderPromptIds. Declared const: this is read-only lookup data and is
// only ever accessed via find(), so a mutable global invited accidental writes.
static const std::map<std::string, std::string> LANGUAGES = {
    {"en", "english"}, {"zh", "chinese"}, {"de", "german"}, {"es", "spanish"}, {"ru", "russian"},
    {"ko", "korean"}, {"fr", "french"}, {"ja", "japanese"}, {"pt", "portuguese"}, {"tr", "turkish"},
    {"pl", "polish"}, {"ca", "catalan"}, {"nl", "dutch"}, {"ar", "arabic"}, {"sv", "swedish"},
    {"it", "italian"}, {"id", "indonesian"}, {"hi", "hindi"}, {"fi", "finnish"}, {"vi", "vietnamese"},
    {"he", "hebrew"}, {"uk", "ukrainian"}, {"el", "greek"}, {"ms", "malay"}, {"cs", "czech"},
    {"ro", "romanian"}, {"da", "danish"}, {"hu", "hungarian"}, {"ta", "tamil"}, {"no", "norwegian"},
    {"th", "thai"}, {"ur", "urdu"}, {"hr", "croatian"}, {"bg", "bulgarian"}, {"lt", "lithuanian"},
    {"la", "latin"}, {"mi", "maori"}, {"ml", "malayalam"}, {"cy", "welsh"}, {"sk", "slovak"},
    {"te", "telugu"}, {"fa", "persian"}, {"lv", "latvian"}, {"bn", "bengali"}, {"sr", "serbian"},
    {"az", "azerbaijani"}, {"sl", "slovenian"}, {"kn", "kannada"}, {"et", "estonian"}, {"mk", "macedonian"},
    {"br", "breton"}, {"eu", "basque"}, {"is", "icelandic"}, {"hy", "armenian"}, {"ne", "nepali"},
    {"mn", "mongolian"}, {"bs", "bosnian"}, {"kk", "kazakh"}, {"sq", "albanian"}, {"sw", "swahili"},
    {"gl", "galician"}, {"mr", "marathi"}, {"pa", "punjabi"}, {"si", "sinhala"}, {"km", "khmer"},
    {"sn", "shona"}, {"yo", "yoruba"}, {"so", "somali"}, {"af", "afrikaans"}, {"oc", "occitan"},
    {"ka", "georgian"}, {"be", "belarusian"}, {"tg", "tajik"}, {"sd", "sindhi"}, {"gu", "gujarati"},
    {"am", "amharic"}, {"yi", "yiddish"}, {"lo", "lao"}, {"uz", "uzbek"}, {"fo", "faroese"},
    {"ht", "haitian creole"}, {"ps", "pashto"}, {"tk", "turkmen"}, {"nn", "nynorsk"}, {"mt", "maltese"},
    {"sa", "sanskrit"}, {"lb", "luxembourgish"}, {"my", "myanmar"}, {"bo", "tibetan"}, {"tl", "tagalog"},
    {"mg", "malagasy"}, {"as", "assamese"}, {"tt", "tatar"}, {"haw", "hawaiian"}, {"ln", "lingala"},
    {"ha", "hausa"}, {"ba", "bashkir"}, {"jw", "javanese"}, {"su", "sundanese"}, {"yue", "cantonese"}};

// Build the Whisper decoder prompt sequence — [<|lang|>], [<|task|>], [<|notimestamps|>]
// in that order — and replicate it for each of the batch_size rows in t_ids.
OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
                                              std::vector<std::vector<extTokenId_t>>& t_ids) const {
  if (tokenizer_ == nullptr) {
    return OrtxStatus(kOrtxErrorInvalidArgument, "Tokenizer is not loaded");
  }
  // since it was only supported by Whisper model, should we check it here?

  // Encode a single special-token string and return its id. The id of interest is
  // taken from the middle of the encoded output — presumably because the encoder
  // wraps the input with begin/end markers; confirm against the BPE encoder config.
  // Returns extTokenId_t(-1) if encoding fails.
  auto convert_tokens_to_ids = [this](const std::string& token) -> extTokenId_t {
    ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
    ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{token});
    auto status = this->tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
    if (!status.IsOk()) {
      return static_cast<extTokenId_t>(-1);
    }
    auto num = ts_output.NumberOfElement();
    return static_cast<extTokenId_t>(ts_output.Data()[num / 2]);  // get the middle token
  };

  std::vector<extTokenId_t> ids;
  ids.reserve(4);

  if (lang != nullptr) {
    auto lang_it = LANGUAGES.find(lang);
    if (lang_it == LANGUAGES.end()) {
      return OrtxStatus(kOrtxErrorInvalidArgument, "Invalid language");
    }
    ids.push_back(convert_tokens_to_ids("<|" + lang_it->first + "|>"));
  }

  if (task != nullptr) {
    // BUG FIX: the previous condition was `0 == strcmp(task, "translate") == 0`,
    // which parses as `(0 == strcmp(...)) == 0` and is false exactly when the
    // strings match, so the "translate" task was always rejected as invalid.
    // Token ids are now also converted lazily, only for the branch actually taken.
    if (strcmp(task, "translate") == 0) {
      ids.push_back(convert_tokens_to_ids("<|translate|>"));
    } else if (strcmp(task, "transcribe") == 0) {
      ids.push_back(convert_tokens_to_ids("<|transcribe|>"));
    } else {
      return OrtxStatus(kOrtxErrorInvalidArgument, "Invalid task");
    }
  }

  if (no_timestamps) {
    ids.push_back(convert_tokens_to_ids("<|notimestamps|>"));
  }

  // Every row of the batch receives the same prompt-id sequence.
  t_ids.resize(batch_size, ids);
  return {};
}
14 changes: 8 additions & 6 deletions shared/api/tokenizer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ class TokenizerImpl : public OrtxObjectImpl {
public:
OrtxStatus Load(const std::string& dir);

OrtxStatus Tokenize(const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const {
// Encode a batch of strings into token-id sequences (thin wrapper over BatchEncode).
OrtxStatus Tokenize(const std::vector<std::string_view>& input, std::vector<std::vector<extTokenId_t>>& t_ids) const {
  return BatchEncode(input, t_ids);
}

OrtxStatus Detokenize(const std::vector<span<extTokenId_t const>>& t_ids,
std::vector<std::string>& t_text) const {
// Decode a batch of token-id spans back into strings (thin wrapper over BatchDecode).
OrtxStatus Detokenize(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const {
  return BatchDecode(t_ids, t_text);
}

Expand All @@ -41,11 +39,15 @@ class TokenizerImpl : public OrtxObjectImpl {
return status;
}

OrtxStatus BatchEncode(const std::vector<std::string_view>& input, std::vector<std::vector<extTokenId_t>>& t_ids) const;
OrtxStatus BatchEncode(const std::vector<std::string_view>& input,
std::vector<std::vector<extTokenId_t>>& t_ids) const;

OrtxStatus BatchDecode(const std::vector<span<extTokenId_t const>>& t_ids, std::vector<std::string>& t_text) const;

OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state ) const;
OrtxStatus Id2Token(extTokenId_t id, std::string& token, BPEDecoderState** state) const;

// Build the Whisper decoder prompt ids (language / task / no-timestamps special
// tokens) and replicate the sequence for each of the batch_size rows in t_ids.
OrtxStatus GetDecoderPromptIds(size_t batch_size, const char* lang, const char* task, int no_timestamps,
                               std::vector<std::vector<extTokenId_t>>& t_ids) const;

private:
std::string tokenizer_dir_;
Expand Down
1 change: 1 addition & 0 deletions shared/extensions_c.def
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ EXPORTS
OrtxStringArrayGetItem @14
OrtxTokenId2DArrayGetBatch @15
OrtxTokenId2DArrayGetItem @16
OrtxGetDecoderPromptIds @17
Loading

0 comments on commit c3145b8

Please sign in to comment.