
Commit 711a2cf

add a convert_token_string_to_an_id API for the prompt ids (microsoft#794)

* add a convert token string to an id API for the prompt ids
* fix the build issues on Linux

1 parent 6ce22f8 commit 711a2cf

8 files changed (+71 −42 lines changed)
include/ortx_tokenizer.h (+11)
@@ -37,6 +37,17 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
 extError_t ORTX_API_CALL OrtxTokenize(
     const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);
 
+
+/**
+ * Converts a token to its corresponding ID.
+ *
+ * @param tokenizer The tokenizer object.
+ * @param token The input token to be converted.
+ * @param id Pointer to store the converted token ID.
+ * @return The error code indicating the success or failure of the conversion.
+ */
+extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, const char* token, extTokenId_t* id);
+
 /**
  * @brief Retrieves the decoder prompt IDs from the tokenizer.
  *
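A minimal usage sketch of the new API (the tokenizer directory path is a placeholder, and object disposal and full error handling are omitted):

  OrtxTokenizer* tokenizer = nullptr;
  // "/path/to/tokenizer_dir" is hypothetical; point it at a real tokenizer model directory.
  extError_t err = OrtxCreateTokenizer(&tokenizer, "/path/to/tokenizer_dir");
  if (err == kOrtxOK) {
    extTokenId_t id = 0;
    // Look up a special token's ID without running the full tokenization pipeline.
    err = OrtxConvertTokenToId(tokenizer, "<|startoftranscript|>", &id);
    // On success, id holds the token's ID (50258 for the Whisper tokenizer, per the test below).
  }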

operators/tokenizer/bpe_kernels.cc (+10 −1)
@@ -171,6 +171,15 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
   return {};
 }
 
+uint32_t KernelBpeTokenizer::GetTokenId(const std::string& token) const {
+  auto id = bbpe_tokenizer_->GetAddedTokenId(token);
+  if (id != bpe::kInvalidTokenId) {
+    return id;
+  }
+
+  return bbpe_tokenizer_->GetTokenId(token);
+}
+
 std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
                                                   int64_t max_length,
                                                   bool compute_offset_mapping,
@@ -778,4 +787,4 @@ OrtxStatus JsonFastTokenizer::Compute(const ortc::Tensor<std::string>& input,
                                       std::optional<ortc::Tensor<int64_t>*> attention_mask,
                                       std::optional<ortc::Tensor<int64_t>*> offset_mapping) const {
   return KernelBpeTokenizer::Compute(input, tokenize_output, attention_mask, offset_mapping);
-}
+}
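In the new KernelBpeTokenizer::GetTokenId above, the added-token table is consulted first, so special tokens such as <|translate|> resolve by direct lookup; only on a miss does it fall back to the base vocabulary.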

operators/tokenizer/bpe_kernels.h (+2 −1)
@@ -33,6 +33,7 @@ struct KernelBpeTokenizer {
                     std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
 
   const std::string& ModelName() const { return model_name_; }
+  uint32_t GetTokenId(const std::string& token) const;
 
  protected:
   using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
@@ -104,7 +105,7 @@ struct SpmTokenizer : KernelBpeTokenizer {
   }
 };
 
-class JsonFastTokenizer : KernelBpeTokenizer {
+class JsonFastTokenizer : public KernelBpeTokenizer {
  public:
   JsonFastTokenizer();
   bool tiktoken_ = false;

operators/tokenizer/bpe_tokenizer.hpp (+19 −20)
@@ -44,11 +44,8 @@ class BpeModel {
     }
   }
 
-  OrtxStatus Load(std::istream& vocab_stream,
-                  std::istream& merges_stream,
-                  const char* unk_token,
-                  const char* special_tokens,
-                  bool spm_converted) {
+  OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
+                  const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
     vocab_stream >> tok_json;
     tok_json.get_to(vocab_map_);
@@ -125,9 +122,7 @@ class BpeModel {
     return {};
   }
 
-  OrtxStatus Load(const json& bpe_model,
-                  const char* /* special_tokens */,
-                  bool spm_converted) {
+  OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -195,8 +190,7 @@ class BpeModel {
   }
 
   OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
-                  std::vector<std::pair<std::string, std::string>>& merges,
-                  const char* /* special_tokens */,
+                  std::vector<std::pair<std::string, std::string>>& merges, const char* /* special_tokens */,
                   bool spm_converted) {
     vocab_map_ = vocab;
 
@@ -207,7 +201,7 @@ class BpeModel {
     }
 
     uint32_t index = 0;
-    for (auto& tuple : merges){
+    for (auto& tuple : merges) {
       std::string w1 = tuple.first;
       std::string w2 = tuple.second;
       int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
@@ -269,11 +263,10 @@ class BpeModel {
     return {};
   }
 
-  std::vector<std::string> BuildDecoder() const {
-    return id2token_map_;
-  }
+  std::vector<std::string> BuildDecoder() const { return id2token_map_; }
 
-  // REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
+  // REF:
+  // https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
   bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
     // split by added tokens
     bpe::TokenPairs added_result;
@@ -343,9 +336,7 @@ class BpeModel {
     }
   }
 
-  const auto& ByteEncoder() const {
-    return byte_encoder_;
-  }
+  const auto& ByteEncoder() const { return byte_encoder_; }
 
   uint32_t GetTokenId(const std::string& key) const {
     auto it = vocab_map_.find(key);
@@ -356,10 +347,18 @@ class BpeModel {
     }
   }
 
-  const std::string& GetEndOfWordSuffix() const {
-    return end_of_word_suffix_;
+  uint32_t GetAddedTokenId(const std::string& key) const {
+    size_t idx = 0;
+    int id = added_tokens_.FindLongest(ustring(key), idx);
+    if (idx == 0) {
+      return bpe::kInvalidTokenId;
+    }
+
+    return static_cast<uint32_t>(id);
   }
 
+  const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
+
  private:
   struct BpeNode {
     uint32_t id;
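Note the GetAddedTokenId contract: FindLongest writes the matched length into idx, so idx == 0 means no added token matched and bpe::kInvalidTokenId is returned; otherwise the added token's id is returned.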

shared/api/c_api_tokenizer.cc (+15)
@@ -76,6 +76,21 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char
   return extError_t();
 }
 
+extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, const char* token, extTokenId_t* id) {
+  if (tokenizer == nullptr || token == nullptr || id == nullptr) {
+    ReturnableStatus::last_error_message_ = "Invalid argument";
+    return kOrtxErrorInvalidArgument;
+  }
+  auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
+  ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
+  if (!status.IsOk()) {
+    return status.Code();
+  }
+
+  status = token_ptr->Token2Id(token, *id);
+  return status.Code();
+}
+
 extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang,
                                                  const char* task, int no_timestamps, OrtxTokenId2DArray** output) {
   if (tokenizer == nullptr || output == nullptr) {
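The new C entry point follows the pattern of the neighboring API functions: validate the arguments, check the object kind via IsInstanceOf, then delegate to TokenizerImpl::Token2Id.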

shared/api/tokenizer_impl.cc (+4 −15)
@@ -110,20 +110,9 @@ OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lan
   }
   // since it was only supported by Whisper model, should we check it here?
 
-  auto convert_tokens_to_ids = [this](const std::string& token) -> extTokenId_t {
-    ortc::Tensor<int64_t> ts_output(&CppAllocator::Instance());
-    ortc::Tensor<std::string> ts_input = ortc::Tensor<std::string>(std::vector<std::string>{std::string(token)});
-    auto status = this->tokenizer_->Compute(ts_input, ts_output, std::nullopt, std::nullopt);
-    if (!status.IsOk()) {
-      return static_cast<extTokenId_t>(-1);
-    }
-    auto num = ts_output.NumberOfElement();
-    return static_cast<extTokenId_t>(ts_output.Data()[num / 2]);  // get the middle token
-  };
-
-  auto translate_token_id = convert_tokens_to_ids("<|translate|>");
-  auto transcribe_token_id = convert_tokens_to_ids("<|transcribe|>");
-  auto notimestamps_token_id = convert_tokens_to_ids("<|notimestamps|>");
+  auto translate_token_id = tokenizer_->GetTokenId("<|translate|>");
+  auto transcribe_token_id = tokenizer_->GetTokenId("<|transcribe|>");
+  auto notimestamps_token_id = tokenizer_->GetTokenId("<|notimestamps|>");
   std::vector<extTokenId_t> ids;
   ids.reserve(4);
   if (lang != nullptr) {
@@ -133,7 +122,7 @@ OrtxStatus TokenizerImpl::GetDecoderPromptIds(size_t batch_size, const char* lan
     }
 
     std::string lang_token = "<|" + lang_str->first + "|>";
-    ids.push_back(convert_tokens_to_ids(lang_token));
+    ids.push_back(tokenizer_->GetTokenId(lang_token));
   }
 
   if (task != nullptr) {
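This removes the earlier workaround, which ran each special-token string through the full Compute pipeline and took the middle element of the output tensor; the prompt IDs are now resolved by direct vocabulary lookup through the new GetTokenId.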

shared/api/tokenizer_impl.h (+5 −1)
@@ -27,6 +27,11 @@ class TokenizerImpl : public OrtxObjectImpl {
     return BatchDecode(t_ids, t_text);
   }
 
+  OrtxStatus Token2Id(const std::string& token, extTokenId_t& id) const {
+    id = tokenizer_->GetTokenId(token);
+    return {};
+  }
+
   OrtxStatus Id2Token(extTokenId_t id, std::string& token, std::unique_ptr<BPEDecoderState>& cache) const {
     BPEDecoderState* state_ptr = cache.get();
     OrtxStatus status = Id2Token(id, token, &state_ptr);
@@ -50,7 +55,6 @@ class TokenizerImpl : public OrtxObjectImpl {
                   std::vector<std::vector<extTokenId_t>>& t_ids) const;
 
  private:
-  bool tiktoken = false;
   std::string tokenizer_dir_;
   std::shared_ptr<ort_extensions::bpe::TokenJsonConfig> tok_config_;
   std::unique_ptr<JsonFastTokenizer> tokenizer_;

test/pp_api_test/test_tokenizer.cc (+5 −4)
@@ -410,10 +410,11 @@ TEST(OrtxTokenizerTest, WhisperTokenizer) {
   const extTokenId_t* token_ids = NULL;
   OrtxTokenId2DArrayGetItem(prompt_ids.get(), 0, &token_ids, &length);
   std::vector<extTokenId_t> ids(token_ids, token_ids + length);
-  // std::cout << "Prompt IDs: ";
-  // for (const auto& id : ids) {
-  //   std::cout << id << " ";
-  // }
 
   EXPECT_EQ(ids, std::vector<extTokenId_t>({50259, 50358, 50363}));
+
+  extTokenId_t sot_id{};
+  err = OrtxConvertTokenToId(tokenizer.get(), "<|startoftranscript|>", &sot_id);
+  EXPECT_EQ(err, kOrtxOK);
+  EXPECT_EQ(sot_id, 50258);
 }
