Skip to content

Commit

Permalink
optimize spm tokenizer for long text (microsoft#799)
Browse files Browse the repository at this point in the history
* optimize spm tokenizer for long text

* refine the split logic

* re-trigger CI pipeline.
  • Loading branch information
wenbingl authored Aug 30, 2024
1 parent 6f53237 commit b8b2ebf
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions operators/tokenizer/bpe_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,53 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,
// Get byte encodings prior to performing BPE
std::list<std::pair<uint32_t, uint32_t>> byte_list;

while (res.size() < max_length && char_pos < ustr.length()) {
while (res.size() < max_length && char_pos <= ustr.length()) {
bool split_now = false;
if (char_pos == ustr.length()) {
split_now = true;
}

// temporary split logic, will be replaced regex based split after it is implemented
if (!split_now && byte_list.size() > 10) {
auto is_split_char = [](char32_t ch) {
return ch == U' ' || ch == U'\n' || ch == U'\r' || ch == U'';
};
if (!is_split_char(ustr[char_pos - 1]) && is_split_char(ustr[char_pos])) {
split_now = true;
}
// split immediately to avoid too long byte_list for extreme cases, which is slow.
if (!split_now && byte_list.size() > 100) {
split_now = true;
}
}

if (split_now) {
// Perform BPE
bbpe_tokenizer_->PerformBPE(byte_list);

// Add output to result
for (auto p : byte_list) {
if (res.size() >= max_length) {
break;
}

res.push_back(p.first);

if (compute_offset_mapping) {
offset_mapping.emplace_back(std::make_pair(
offset,
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
offset += ((size_t)p.second);
}
}

byte_list.clear();
}

if (char_pos == ustr.length()) {
break;
}

auto chr = ustr[char_pos];
if (chr == U' ') {
chr = 0x2581; // UTF-8 string '\xe2\x96\x81'
Expand All @@ -436,33 +482,19 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,

char_pos++;
}
{
// Perform BPE
bbpe_tokenizer_->PerformBPE(byte_list);

// Add output to result
for (auto p : byte_list) {
if (res.size() >= max_length) {
break;
}

res.push_back(p.first);

if (compute_offset_mapping) {
offset_mapping.emplace_back(std::make_pair(
offset,
ort_extensions::narrow<size_t>(offset + (size_t)p.second)));
offset += ((size_t)p.second);
}
}
}

if (compute_offset_mapping) {
// Add offset mappings for input in this instance to list of offset mappings for all inputs
offset_map.emplace_back(offset_mapping);
}
}

if (res.size() > 0 && res.front() == bos_token_id_) {
if (add_bos_token_.has_value() && add_bos_token_.value() == false) {
res.erase(res.begin());
}
}

return res;
}

Expand Down

0 comments on commit b8b2ebf

Please sign in to comment.