Fix batching in fairseq SentencepieceTokenizer (#640)
* Move fairseq fix out of the loop.

* Tidying up the patch.
Craigacp authored Jan 30, 2024
1 parent 44e494b commit d47a3dd
Showing 1 changed file with 23 additions and 22 deletions.
operators/tokenizer/sentencepiece_tokenizer.cc (23 additions, 22 deletions)
@@ -83,32 +83,33 @@ OrtStatusPtr KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::strin
         content.push_back(tokenizer_.eos_id());
         token_indices.push_back(ort_extensions::narrow<int32_t>(str_input[i].length()));
       }
-
-      if (fairseq.has_value() && (*fairseq)) {
-        // HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
-        //
-        // Original fairseq vocab and spm vocab must be "aligned":
-        // Vocab    | 0       | 1       | 2      | 3       | 4   | 5   | 6   | 7     | 8     | 9
-        // -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
-        // fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
-        // spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
-        //
-        // As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
-        // 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
-        std::for_each(content.begin(), content.end(), [](int& n) {
-          if (n == 0) {  // '<unk>': 0 -> 3
-            n = 3;
-          } else if (n == 1) {  // '<s>': 1 -> 0
-            n = 0;
-          } else if (n != 2) {  // '</s>': 2 -> 2, '<*>': x -> x + 1
-            n++;
-          }
-        });
-      }
     }
   }
   instance_indices.push_back(content.size());
+
+  // Patch fairseq indices
+  if (fairseq.has_value() && (*fairseq) && !add_rev) {
+    // HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
+    //
+    // Original fairseq vocab and spm vocab must be "aligned":
+    // Vocab    | 0       | 1       | 2      | 3       | 4   | 5   | 6   | 7     | 8     | 9
+    // -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+    // fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+    // spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+    //
+    // As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
+    // 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
+    std::for_each(content.begin(), content.end(), [](int& n) {
+      if (n == 0) {  // '<unk>': 0 -> 3
+        n = 3;
+      } else if (n == 1) {  // '<s>': 1 -> 0
+        n = 0;
+      } else if (n != 2) {  // '</s>': 2 -> 2, '<*>': x -> x + 1
+        n++;
+      }
+    });
+  }
 
   // Setup output
   std::vector<int64_t> size_content(1);
   size_content[0] = content.size();
