Skip to content

Commit 8f2c35f

Browse files
authored
Add more tests for pre-processing C APIs (microsoft#793)
* initial api for tokenizer * More fixings and test data refinement * add a simple wrapper for pre-processing APIs * fix the test issues * test if the tokenizer is spm based * fix the failed test cases * json pointer does not work
1 parent 85ffb94 commit 8f2c35f

File tree

15 files changed

+278
-221
lines changed

15 files changed

+278
-221
lines changed

.pyproject/cmdclass.py

+4
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def initialize_options(self):
147147
self.no_azure = None
148148
self.no_opencv = None
149149
self.cc_debug = None
150+
self.pp_api = None
150151
self.cuda_archs = None
151152
self.ort_pkg_dir = None
152153

@@ -210,6 +211,9 @@ def build_cmake(self, extension):
210211
'-DOCOS_ENABLE_CV2=OFF',
211212
'-DOCOS_ENABLE_VISION=OFF']
212213

214+
if self.pp_api:
215+
cmake_args += ['-DOCOS_ENABLE_C_API=ON']
216+
213217
if self.no_azure is not None:
214218
azure_flag = "OFF" if self.no_azure == 1 else "ON"
215219
cmake_args += ['-DOCOS_ENABLE_AZURE=' + azure_flag]

base/ustring.h

+17-13
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ class ustring : public std::u32string {
1111
public:
1212
ustring() = default;
1313

14-
explicit ustring(const char* str) { assign(FromUTF8(str)); }
14+
explicit ustring(const char* str) { assign(std::move(FromUTF8(str))); }
1515

16-
explicit ustring(const std::string& str) { assign(FromUTF8(str)); }
16+
explicit ustring(const std::string& str) { assign(std::move(FromUTF8(str))); }
1717

18-
explicit ustring(const std::string_view& str) { assign(FromUTF8(str)); }
18+
explicit ustring(const std::string_view& str) { assign(std::move(FromUTF8(str))); }
1919

2020
explicit ustring(const char32_t* str) : std::u32string(str) {}
2121

@@ -76,11 +76,15 @@ class ustring : public std::u32string {
7676
}
7777
}
7878

79-
static bool ValidateUTF8(const std::string& data) {
79+
// return a non-positive value (the negated offset) of the first invalid utf8 char position
80+
// (note: 0 if the very first char is invalid), otherwise the position of the terminating null character, which is the end of the string.
81+
static ptrdiff_t ValidateUTF8(const std::string& data) {
8082
const unsigned char* s = reinterpret_cast<const unsigned char*>(data.c_str());
83+
const unsigned char* s_begin = s;
8184
const unsigned char* s_end = s + data.size();
85+
8286
if (*s_end != '\0')
83-
return false;
87+
return 0;
8488

8589
while (*s) {
8690
if (*s < 0x80)
@@ -89,45 +93,45 @@ class ustring : public std::u32string {
8993
else if ((s[0] & 0xe0) == 0xc0) {
9094
/* 110XXXXx 10xxxxxx */
9195
if (s + 1 >= s_end) {
92-
return false;
96+
return s_begin - s;
9397
}
9498
if ((s[1] & 0xc0) != 0x80 ||
9599
(s[0] & 0xfe) == 0xc0) /* overlong? */
96-
return false;
100+
return s_begin - s;
97101
else
98102
s += 2;
99103
} else if ((s[0] & 0xf0) == 0xe0) {
100104
/* 1110XXXX 10Xxxxxx 10xxxxxx */
101105
if (s + 2 >= s_end) {
102-
return false;
106+
return s_begin - s;
103107
}
104108
if ((s[1] & 0xc0) != 0x80 ||
105109
(s[2] & 0xc0) != 0x80 ||
106110
(s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || /* overlong? */
107111
(s[0] == 0xed && (s[1] & 0xe0) == 0xa0) || /* surrogate? */
108112
(s[0] == 0xef && s[1] == 0xbf &&
109113
(s[2] & 0xfe) == 0xbe)) /* U+FFFE or U+FFFF? */
110-
return false;
114+
return s_begin - s;
111115
else
112116
s += 3;
113117
} else if ((s[0] & 0xf8) == 0xf0) {
114118
/* 11110XXX 10XXxxxx 10xxxxxx 10xxxxxx */
115119
if (s + 3 >= s_end) {
116-
return false;
120+
return s_begin - s;
117121
}
118122
if ((s[1] & 0xc0) != 0x80 ||
119123
(s[2] & 0xc0) != 0x80 ||
120124
(s[3] & 0xc0) != 0x80 ||
121125
(s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || /* overlong? */
122126
(s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) /* > U+10FFFF? */
123-
return false;
127+
return s_begin - s;
124128
else
125129
s += 4;
126130
} else
127-
return false;
131+
return s_begin - s;
128132
}
129133

130-
return true;
134+
return s - s_begin;
131135
}
132136

133137
private:

docs/c_api.md

+10-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,13 @@ Most APIs accept raw data inputs such as audio, image compressed binary formats,
1818

1919
**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extraction.cc#L16).
2020

21-
NB: If onnxruntime-extensions is to build as a shared library, which requires the OCOS_ENABLE_AUDIO OCOS_ENABLE_CV2 OCOS_ENABLE_OPENCV_CODECS OCOS_ENABLE_GPT2_TOKENIZER build flags are ON to have a full function of binary. Only onnxruntime-extensions static library can be used for a minimal build with the selected operators, so in that case, the shared library build can be switched off by `-DOCOS_BUILD_SHARED_LIB=OFF`.
21+
**NB:** If onnxruntime-extensions is built as a shared library, the OCOS_ENABLE_AUDIO, OCOS_ENABLE_CV2, OCOS_ENABLE_OPENCV_CODECS, and OCOS_ENABLE_GPT2_TOKENIZER build flags must be ON for the binary to be fully functional. Only the onnxruntime-extensions static library can be used for a minimal build with the selected operators; in that case, the shared-library build can be switched off by `-DOCOS_BUILD_SHARED_LIB=OFF`.
22+
23+
There is a simple Python wrapper over these C APIs in [pp_api](../onnxruntime_extensions/pp_api.py), which provides easy access to these APIs from Python code, for example:
24+
25+
```Python
26+
from onnxruntime_extensions.pp_api import Tokenizer
27+
# the name can be the same one used by Huggingface transformers.AutoTokenizer
28+
pp_tok = Tokenizer('google/gemma-2-2b')
29+
print(pp_tok.tokenize("what are you? \n 给 weiss ich, über was los ist \n"))
30+
```

docs/custom_ops.md

-9
Original file line numberDiff line numberDiff line change
@@ -531,15 +531,6 @@ expect(node, inputs=[inputs],
531531
</details>
532532

533533

534-
### BlingFireSentenceBreaker
535-
536-
TODO
537-
538-
### BpeTokenizer
539-
540-
TODO
541-
542-
543534
## String operators
544535

545536
### StringEqual

docs/development.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The package contains all custom operators and some Python scripts to manipulate
1616
- no-azure: disable AzureOp kernel build in Python package.
1717
- no-opencv: disable operators based on OpenCV in build.
1818
- cc-debug: generate debug info for extensions binaries and disable C/C++ compiler optimization.
19+
- pp_api: enable the pre-processing C API Python wrapper, `from onnxruntime_extensions.pp_api import *`
1920
- cuda-archs: specify the CUDA architectures(like 70, 85, etc.), and the multiple values can be combined with semicolon. The default value is nvidia-smi util output of GPU-0
2021
- ort\_pkg\_dir: specify ONNXRuntime package directory the extension project is depending on. This is helpful if you want to use some ONNXRuntime latest function which has not been involved in the official build
2122

onnxruntime_extensions/pp_api.py

+60-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,69 @@
33
# license information.
44
###############################################################################
55

6+
import os
67
from . import _extensions_pydll as _C
7-
if not hasattr(_C, "create_processor"):
8-
raise ImportError("onnxruntime_extensions is not built with pre-processing API")
8+
if not hasattr(_C, "delete_object"):
9+
raise ImportError(
10+
"onnxruntime_extensions is not built with pre-processing C API"
11+
"To enable it, please build the package with --ortx-user-option=pp_api")
912

1013
create_processor = _C.create_processor
1114
load_images = _C.load_images
1215
image_pre_process = _C.image_pre_process
1316
tensor_result_get_at = _C.tensor_result_get_at
17+
18+
create_tokenizer = _C.create_tokenizer
19+
batch_tokenize = _C.batch_tokenize
20+
batch_detokenize = _C.batch_detokenize
21+
22+
delete_object = _C.delete_object
23+
24+
25+
class Tokenizer:
26+
def __init__(self, tokenizer_dir):
27+
if os.path.isdir(tokenizer_dir):
28+
self.tokenizer = create_tokenizer(tokenizer_dir)
29+
else:
30+
try:
31+
from transformers.utils import cached_file
32+
resolved_full_file = cached_file(
33+
tokenizer_dir, "tokenizer.json")
34+
resolved_config_file = cached_file(
35+
tokenizer_dir, "tokenizer_config.json")
36+
except ImportError:
37+
raise ValueError(
38+
f"Directory '{tokenizer_dir}' not found and transformers is not available")
39+
if not os.path.exists(resolved_full_file):
40+
raise FileNotFoundError(
41+
f"Downloaded HF file '{resolved_full_file}' cannot be found")
42+
if (os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file)):
43+
raise FileNotFoundError(
44+
f"Downloaded HF files '{resolved_full_file}' and '{resolved_config_file}' are not in the same directory")
45+
46+
tokenizer_dir = os.path.dirname(resolved_full_file)
47+
self.tokenizer = create_tokenizer(tokenizer_dir)
48+
49+
def tokenize(self, text):
50+
return batch_tokenize(self.tokenizer, [text])[0]
51+
52+
def detokenize(self, tokens):
53+
return batch_detokenize(self.tokenizer, [tokens])[0]
54+
55+
def __del__(self):
56+
if delete_object and self.tokenizer:
57+
delete_object(self.tokenizer)
58+
self.tokenizer = None
59+
60+
61+
class ImageProcessor:
62+
def __init__(self, processor_json):
63+
self.processor = create_processor(processor_json)
64+
65+
def pre_process(self, images):
66+
return image_pre_process(self.processor, images)
67+
68+
def __del__(self):
69+
if delete_object and self.processor:
70+
delete_object(self.processor)
71+
self.processor = None

operators/tokenizer/bpe_kernels.cc

+24-10
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ static bool IsBosEosRequired(const std::string& model_name) {
2727
return model_name != kModel_GPT2 && model_name != kModel_CodeGen;
2828
}
2929

30-
static bool IsSpmModel(const std::string& model_name) {
31-
return model_name == kModel_Llama ||
32-
model_name == kModel_Gemma;
33-
}
34-
3530
std::string BpeModelConf::GetSpecialTokens() const {
3631
std::string special_tokens = unk_token_; // unk_token_ is required
3732
auto add_token = [](std::string& sp, const char* tok) {
@@ -145,7 +140,7 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
145140
merges_stream,
146141
bpe_conf_.get().unk_token_,
147142
bpe_conf_.get().GetSpecialTokens().c_str(),
148-
IsSpmModel(ModelName()));
143+
bpe_conf_.get().spm_model_);
149144
if (!status.IsOk()) {
150145
return (OrtStatusPtr)status;
151146
}
@@ -454,7 +449,7 @@ OrtxStatus KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
454449
}
455450

456451
auto tok_fun = &KernelBpeTokenizer::Tokenize;
457-
if (IsSpmModel(ModelName())) {
452+
if (bpe_conf_.get().spm_model_) {
458453
tok_fun = &KernelBpeTokenizer::SpmTokenize;
459454
}
460455

@@ -556,7 +551,8 @@ static const auto kSpmConfiguration = BpeModelConf{
556551
"<unk>", // unk_token
557552
"<s>", // bos_token
558553
"</s>", // eos_token
559-
""}; // pad_token
554+
"", // pad_token
555+
true};
560556

561557
SpmTokenizer::SpmTokenizer()
562558
: KernelBpeTokenizer(kSpmConfiguration) {}
@@ -718,15 +714,33 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::bpe::TokenJsonConfig& c
718714
module_ifs >> tok_json;
719715
} else {
720716
ifs >> tok_json;
717+
// auto decoders_node = tok_json.find("/decoder/decoders"_json_pointer);
718+
auto decoders_node = tok_json.find("decoder");
719+
if (decoders_node != tok_json.end()) {
720+
decoders_node = decoders_node->find("decoders");
721+
}
722+
723+
if (decoders_node->is_array()) {
724+
for(auto step = decoders_node->begin(); step != decoders_node->end(); ++step) {
725+
std::string type = step->value("type", "");
726+
if (type == "Replace") {
727+
std::string target = step->value("/pattern/String"_json_pointer, "");
728+
if (target == "\xe2\x96\x81") {
729+
json_conf_.spm_model_ = true;
730+
break;
731+
}
732+
}
733+
}
734+
}
721735
auto model_node = tok_json.find("model");
722736
if (model_node == tok_json.end()) {
723737
return OrtxStatus(kOrtxErrorCorruptData, "Failed to get model node from tokenizer.json");
724738
}
725739

726740
bbpe_tokenizer_ = std::make_unique<BpeModel>();
727741
status = bbpe_tokenizer_->Load(*model_node,
728-
bpe_conf_.get().GetSpecialTokens().c_str(),
729-
IsSpmModel(ModelName()));
742+
bpe_conf_.get().GetSpecialTokens().c_str(),
743+
bpe_conf_.get().spm_model_);
730744
}
731745

732746

operators/tokenizer/bpe_kernels.h

+5-22
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct BpeModelConf {
2020
const char* eos_token_{"<|endoftext|>"};
2121
const char* pad_token_{nullptr};
2222

23+
bool spm_model_{};
2324
std::string GetSpecialTokens() const;
2425
};
2526

@@ -108,10 +109,6 @@ struct SpmTokenizer : KernelBpeTokenizer {
108109
class JsonFastTokenizer : public KernelBpeTokenizer {
109110
public:
110111
JsonFastTokenizer();
111-
bool tiktoken_ = false;
112-
std::string unicode_byte_encoder_[256] = {};
113-
void CreateUnicodeByteEncoder();
114-
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
115112
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
116113
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
117114
ortc::Tensor<int64_t>& tokenize_output,
@@ -121,28 +118,14 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
121118
public:
122119
const auto& GetAddedTokens() const { return added_tokens_; }
123120
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
121+
bool IsSpmModel() const { return json_conf_.spm_model_; }
122+
bool tiktoken_ = false;
124123

125124
private:
126-
BpeModelConf json_conf_;
127-
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
128-
};
129-
130-
class TikTokenizer : KernelBpeTokenizer {
131-
public:
132-
TikTokenizer();
125+
void CreateUnicodeByteEncoder();
133126
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
134-
OrtxStatus Load(const ort_extensions::bpe::TokenJsonConfig& config);
135-
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
136-
ortc::Tensor<int64_t>& tokenize_output,
137-
std::optional<ortc::Tensor<int64_t>*> attention_mask,
138-
std::optional<ortc::Tensor<int64_t>*> offset_mapping) const;
139127

140-
public:
141-
const auto& GetAddedTokens() const { return added_tokens_; }
142-
const ort_extensions::BpeModel& GetEncoder() const { return *bbpe_tokenizer_; }
143-
144-
private:
145-
std::unique_ptr<ort_extensions::BpeModel>bbpe_tokenizer_;
146128
BpeModelConf json_conf_;
147129
std::vector<ort_extensions::bpe::AddedToken> added_tokens_;
130+
std::string unicode_byte_encoder_[256] = {};
148131
};

0 commit comments

Comments
 (0)