@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
 {
     public class BertTokenizerTests
     {
+        [Fact]
+        public void TestWithLowerCasingExplicitSpecialTokens()
+        {
+            // Add [SPECIAL] token at the end (to keep the other indices as is)
+            // Ids:                   0        1        2        3        4        5    6    7    8        9        10     11     12     13
+            string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
+
+            string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
+
+            Dictionary<string, int> specialTokens = new() {
+                { "[PAD]", 0 },
+                { "[UNK]", 1 },
+                { "[CLS]", 2 },
+                { "[SEP]", 3 },
+                { "[MASK]", 4 },
+                { "[SPECIAL]", 13 },
+            };
+            var bertOptions = new BertOptions()
+            {
+                SpecialTokens = specialTokens
+            };
+
+            try
+            {
+                using Stream vocabStream = File.OpenRead(vocabFile);
+                BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];
+
+                foreach (var tokenizer in bertTokenizers)
+                {
+                    Assert.NotNull(tokenizer.PreTokenizer);
+                    Assert.Equal("[UNK]", tokenizer.UnknownToken);
+                    Assert.Equal(1, tokenizer.UnknownTokenId);
+                    Assert.NotNull(tokenizer.Normalizer);
+                    Assert.NotNull(tokenizer.PreTokenizer);
+
+                    Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));
+
+                    string text = "Hello, How are you [SPECIAL]?";
+                    var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
+                    Assert.Equal("hello, how are you [special]?", normalizedText);
+
+                    Assert.Equal(
+                        [
+                            new EncodedToken(8, "hello", new Range(0, 5)),
+                            new EncodedToken(6, ",", new Range(5, 6)),
+                            new EncodedToken(10, "how", new Range(7, 10)),
+                            new EncodedToken(11, "are", new Range(11, 14)),
+                            new EncodedToken(12, "you", new Range(15, 18)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
+                            new EncodedToken(7, "?", new Range(28, 29))
+                        ],
+                        tokens);
+
+                    var ids = tokenizer.EncodeToIds(text);
+                    Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);
+
+                    Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
+                    Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
+
+                    tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
+                    Assert.Equal("[cls] hello, how are you [special]? [sep]", normalizedText);
+                    Assert.Equal(
+                        [
+                            new EncodedToken(2, "[CLS]", new Range(0, 5)),
+                            new EncodedToken(8, "hello", new Range(6, 11)),
+                            new EncodedToken(6, ",", new Range(11, 12)),
+                            new EncodedToken(10, "how", new Range(13, 16)),
+                            new EncodedToken(11, "are", new Range(17, 20)),
+                            new EncodedToken(12, "you", new Range(21, 24)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
+                            new EncodedToken(7, "?", new Range(34, 35)),
+                            new EncodedToken(3, "[SEP]", new Range(36, 41))
+                        ],
+                        tokens);
+
+                    ids = tokenizer.EncodeToIds(normalizedText!);
+                    Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids);
+                }
+            }
+            finally
+            {
+                File.Delete(vocabFile);
+            }
+        }
+
         [Fact]
         public void TestWithLowerCasing()
         {
@@ -35,6 +120,10 @@ public void TestWithLowerCasing()
                     Assert.NotNull(tokenizer.Normalizer);
                     Assert.NotNull(tokenizer.PreTokenizer);
 
+                    // Make sure the SpecialTokens dictionary contains the non-normalized tokens
+                    Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
+                    Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));
+
                     string text = "Hello, How are you?";
                     var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
                     Assert.Equal("hello, how are you?", normalizedText);
@@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
             }
         }
     }
-}
+}