Skip to content

Commit 01c4164

Browse files
Moved SpecialTokens assignment after the modification to avoid "Collection Modified" error (#7328)
* Moved special tokens assignment below so the collection won't be modified
* Added safe dictionary inversion
* Added storing the not-normalized special tokens
* Added support for .NET Standard
* Added and updated tests
* Updated without additional memory allocation
1 parent fb7cc25 commit 01c4164

File tree

3 files changed: +123 −21 lines changed

3 files changed: +123 −21 lines changed

src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs

+29-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -762,15 +762,16 @@ private static BertTokenizer Create(
762762

763763
options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null;
764764

765+
IReadOnlyDictionary<string, int>? specialTokensDict = options.SpecialTokens;
765766
if (options.SplitOnSpecialTokens)
766767
{
767768
bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization;
768769
if (options.SpecialTokens is not null)
769770
{
770771
if (lowerCase)
771772
{
772-
Dictionary<string, int> dic = options.SpecialTokens.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
773-
options.SpecialTokens = dic;
773+
Dictionary<string, int> tempSpecialTokens = [];
774+
specialTokensDict = tempSpecialTokens;
774775

775776
foreach (var kvp in options.SpecialTokens)
776777
{
@@ -779,37 +780,49 @@ private static BertTokenizer Create(
779780
throw new ArgumentException($"The special token '{kvp.Key}' is not in the vocabulary or assigned id value {id} different than the value {kvp.Value} in the special tokens.");
780781
}
781782

782-
// Ensure that the special tokens are lowercased.
783-
dic[kvp.Key.ToLowerInvariant()] = kvp.Value;
783+
// Add the special token into our dictionary, normalizing it, and adding it into the
784+
// main vocab, if needed.
785+
AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true);
784786
}
785787
}
786788
}
787789
else
788790
{
789-
// Create a dictionary with the special tokens.
790-
Dictionary<string, int> specialTokens = new Dictionary<string, int>();
791-
options.SpecialTokens = specialTokens;
792-
793-
AddSpecialToken(vocab, specialTokens, options.UnknownToken, lowerCase);
794-
AddSpecialToken(vocab, specialTokens, options.SeparatorToken, lowerCase);
795-
AddSpecialToken(vocab, specialTokens, options.PaddingToken, lowerCase);
796-
AddSpecialToken(vocab, specialTokens, options.ClassificationToken, lowerCase);
797-
AddSpecialToken(vocab, specialTokens, options.MaskingToken, lowerCase);
791+
// Create a dictionary with the special tokens - store the un-normalized forms in the options as
792+
// that field is exposed to the public. In addition, store the normalized form for creating the
793+
// pre-tokenizer.
794+
Dictionary<string, int> tempSpecialTokens = [];
795+
Dictionary<string, int> notNormalizedSpecialTokens = [];
796+
AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
797+
AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
798+
AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
799+
AddSpecialToken(vocab, tempSpecialTokens, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
800+
AddSpecialToken(vocab, tempSpecialTokens, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);
801+
802+
options.SpecialTokens = notNormalizedSpecialTokens;
803+
specialTokensDict = tempSpecialTokens;
798804
}
799805
}
800806

801-
options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? options.SpecialTokens : null) : PreTokenizer.CreateWhiteSpace();
807+
// We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can
808+
// keep the not-normalized special tokens dict in the options passed to the WordPieceTokenizer.
809+
options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? specialTokensDict : null) : PreTokenizer.CreateWhiteSpace();
802810

803811
return new BertTokenizer(vocab, vocabReverse, options);
804812
}
805813

806-
private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
814+
private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase, Dictionary<string, int>? notNormalizedSpecialTokens = null)
807815
{
808816
if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
809817
{
810818
throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
811819
}
812820

821+
if (notNormalizedSpecialTokens is not null)
822+
{
823+
notNormalizedSpecialTokens[token] = id;
824+
}
825+
813826
string normalizedToken = token;
814827
if (lowerCase)
815828
{

src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -42,7 +42,7 @@ internal WordPieceTokenizer(
4242
options ??= new();
4343

4444
SpecialTokens = options.SpecialTokens;
45-
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
45+
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.GroupBy(kvp => kvp.Value).ToDictionary(g => g.Key, g => g.First().Key) : null;
4646

4747
if (options.UnknownToken is null)
4848
{
@@ -800,4 +800,4 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
800800
return OperationStatus.Done;
801801
}
802802
}
803-
}
803+
}

test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs

+91-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
1414
{
1515
public class BertTokenizerTests
1616
{
17+
[Fact]
18+
public void TestWithLowerCasingExplicitSpecialTokens()
19+
{
20+
// Add [SPECIAL] token at end (to keep indices as is)
21+
// Ids: 0 1 2 3 4 5 6 7 8 9 10 11 12, 13
22+
string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
23+
24+
string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
25+
26+
Dictionary<string, int> specialTokens = new() {
27+
{ "[PAD]", 0 },
28+
{ "[UNK]", 1 },
29+
{ "[CLS]", 2 },
30+
{ "[SEP]", 3 },
31+
{ "[MASK]", 4 },
32+
{ "[SPECIAL]", 13 },
33+
};
34+
var bertOptions = new BertOptions()
35+
{
36+
SpecialTokens = specialTokens
37+
};
38+
39+
try
40+
{
41+
using Stream vocabStream = File.OpenRead(vocabFile);
42+
BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];
43+
44+
foreach (var tokenizer in bertTokenizers)
45+
{
46+
Assert.NotNull(tokenizer.PreTokenizer);
47+
Assert.Equal("[UNK]", tokenizer.UnknownToken);
48+
Assert.Equal(1, tokenizer.UnknownTokenId);
49+
Assert.NotNull(tokenizer.Normalizer);
50+
Assert.NotNull(tokenizer.PreTokenizer);
51+
52+
Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));
53+
54+
string text = "Hello, How are you [SPECIAL]?";
55+
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
56+
Assert.Equal("hello, how are you [special]?", normalizedText);
57+
58+
Assert.Equal(
59+
[
60+
new EncodedToken(8, "hello", new Range(0, 5)),
61+
new EncodedToken(6, ",", new Range(5, 6)),
62+
new EncodedToken(10, "how", new Range(7, 10)),
63+
new EncodedToken(11, "are", new Range(11, 14)),
64+
new EncodedToken(12, "you", new Range(15, 18)),
65+
new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
66+
new EncodedToken(7, "?", new Range(28, 29))
67+
],
68+
tokens);
69+
70+
var ids = tokenizer.EncodeToIds(text);
71+
Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);
72+
73+
Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
74+
Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
75+
76+
tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
77+
Assert.Equal("[cls] hello, how are you [special]? [sep]", normalizedText);
78+
Assert.Equal(
79+
[
80+
new EncodedToken(2, "[CLS]", new Range(0, 5)),
81+
new EncodedToken(8, "hello", new Range(6, 11)),
82+
new EncodedToken(6, ",", new Range(11, 12)),
83+
new EncodedToken(10, "how", new Range(13, 16)),
84+
new EncodedToken(11, "are", new Range(17, 20)),
85+
new EncodedToken(12, "you", new Range(21, 24)),
86+
new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
87+
new EncodedToken(7, "?", new Range(34, 35)),
88+
new EncodedToken(3, "[SEP]", new Range(36, 41))
89+
],
90+
tokens);
91+
92+
ids = tokenizer.EncodeToIds(normalizedText!);
93+
Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids);
94+
}
95+
}
96+
finally
97+
{
98+
File.Delete(vocabFile);
99+
}
100+
}
101+
17102
[Fact]
18103
public void TestWithLowerCasing()
19104
{
@@ -35,6 +120,10 @@ public void TestWithLowerCasing()
35120
Assert.NotNull(tokenizer.Normalizer);
36121
Assert.NotNull(tokenizer.PreTokenizer);
37122

123+
// Make sure the SpecialTokens dictionary contains the not-normalized tokens
124+
Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
125+
Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));
126+
38127
string text = "Hello, How are you?";
39128
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
40129
Assert.Equal("hello, how are you?", normalizedText);
@@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
511600
}
512601
}
513602
}
514-
}
603+
}

0 commit comments

Comments (0)