import os
import pickle
import tempfile
import numpy as np
import sentencepiece as spm
import tiktoken
from tqdm import tqdm  # For progress bars
from transformers import AutoTokenizer


class Tokenizer:
    def __init__(self, args):
        self.args = args

    def tokenize(self, data):
        raise NotImplementedError("Tokenize method must be implemented by subclasses.")

    def detokenize(self, ids):
        raise NotImplementedError("Detokenize method must be implemented by subclasses.")

    def save_meta(self, meta):
        with open("meta.pkl", "wb") as f:
            pickle.dump(meta, f)

    @staticmethod
    def get_key_from_meta(keyname):
        meta_path = 'meta.pkl'
        if os.path.exists(meta_path):
            with open(meta_path, 'rb') as f:
                meta = pickle.load(f)
            return meta.get(keyname)
        return None


class NumericRangeTokenizer(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        self.min_token = args.min_token
        self.max_token = args.max_token
        self.stoi = None
        self.itos = None

    def tokenize(self, data):
        tokens = []
        encountered_tokens = set()
        lines = data.strip().split('\n')
        for line in tqdm(lines, desc="Tokenizing Numeric Range"):
            try:
                num = int(line)
                if self.min_token <= num <= self.max_token:
                    tokens.append(num)
                    encountered_tokens.add(num)
                else:
                    print(f"Warning: Number {num} is outside the specified range and will be skipped.")
            except ValueError:
                print(f"Warning: Invalid number '{line}' will be skipped.")

        all_tokens = list(range(self.max_token, -1, -1))
        self.stoi = {str(num): i for i, num in enumerate(all_tokens)}
        self.itos = {i: str(num) for i, num in enumerate(all_tokens)}

        indexed_tokens = [self.stoi[str(token)] for token in tokens]
        meta = {
            "vocab_size": len(self.stoi),
            "tokenizer": "numeric_range",
            "min_token": self.min_token,
            "max_token": self.max_token,
            "stoi": self.stoi,
            "itos": self.itos,
            "encountered_tokens": sorted(encountered_tokens, reverse=True)
        }
        self.save_meta(meta)
        return indexed_tokens

    def detokenize(self, ids):
        return '\n'.join([self.itos[id] for id in ids])


class SentencePieceTokenizer(Tokenizer):
    def __init__(self, args, input_files=None):
        super().__init__(args)
        self.vocab_size = args.vocab_size
        self.spm_model_file = args.spm_model_file
        self.spm_vocab_file = args.spm_vocab_file
        self.skip_tokenization = args.skip_tokenization
        self.input_files = input_files
        self.sp = None

        if self.spm_model_file:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.spm_model_file)
        elif input_files:
            self.sp = self.train_sentencepiece_model()

    def train_sentencepiece_model(self):
        spm_model_prefix = "trained_spm_model"
        num_threads = os.cpu_count()
        input_arg = ""
        if isinstance(self.input_files, list):
            # Concatenate multiple input files into a single temporary file for training
            with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile:
                for input_file in self.input_files:
                    with open(input_file, "r") as infile:
                        tmpfile.write(infile.read())
                input_arg = tmpfile.name
        else:
            input_arg = self.input_files

        spm.SentencePieceTrainer.train(
            num_threads=num_threads,
            user_defined_symbols="\n, ",
            input=input_arg,
            model_prefix=spm_model_prefix,
            split_digits=True,
            vocab_size=self.vocab_size,
            model_type="bpe",
        )
        print("SentencePiece model training complete.")

        if isinstance(self.input_files, list):
            os.remove(input_arg)

        sp = spm.SentencePieceProcessor()
        sp.load(f"{spm_model_prefix}.model")
        return sp

    def tokenize(self, data):
        if not self.sp:
            raise ValueError("SentencePiece model is not loaded.")
        ids = self.sp.encode_as_ids(data)
        stoi = {self.sp.id_to_piece(id): id for id in range(self.sp.GetPieceSize())}
        itos = {id: self.sp.id_to_piece(id) for id in range(self.sp.GetPieceSize())}

        meta = {
            "vocab_size": self.sp.GetPieceSize(),
            "tokenizer": "sentencepiece",
            "stoi": stoi,
            "itos": itos,
        }
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        if not self.sp:
            raise ValueError("SentencePiece model is not loaded.")
        return self.sp.decode_ids(ids)


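# Illustrative sketch (not part of the original module): one way the class above might
# be used to train a SentencePiece model from raw text. The attribute names mirror
# those read in SentencePieceTokenizer.__init__; the file name "input.txt" and the
# vocab size of 512 are assumptions for this example only.
#
#   from argparse import Namespace
#   args = Namespace(vocab_size=512, spm_model_file=None, spm_vocab_file=None,
#                    skip_tokenization=False)
#   sp_tok = SentencePieceTokenizer(args, input_files=["input.txt"])
#   ids = sp_tok.tokenize(open("input.txt", encoding="utf-8").read())

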
class TiktokenTokenizer(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        self.tiktoken_encoding = args.tiktoken_encoding
        self.enc = tiktoken.get_encoding(self.tiktoken_encoding)
        self.vocab_size = self.enc.n_vocab

    def tokenize(self, data):
        ids = self.enc.encode_ordinary(data)
        meta = {
            "vocab_size": self.vocab_size,
            "tokenizer": "tiktoken",
            "tiktoken_encoding": self.tiktoken_encoding,
        }
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        return self.enc.decode(ids)


class CustomTokenizer(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        if args.tokens_file is None:
            raise ValueError("Tokens file must be provided for custom tokenization method.")
        with open(args.tokens_file, "r") as f:
            self.tokens = [line.strip() for line in f.readlines() if line.strip()]
        self.tokens = [token.replace("\\n", "\n").replace("\\t", "\t") for token in self.tokens]
        self.stoi = {token: i for i, token in enumerate(self.tokens)}
        self.itos = {i: token for i, token in enumerate(self.tokens)}

    def tokenize(self, data):
        encoded_data = []
        i = 0
        covered_chars = 0
        data_len = len(data)
        pbar = tqdm(total=data_len, desc="Tokenizing Custom Tokens")
        while i < data_len:
            matched = False
            for token in self.tokens:
                token_len = len(token)
                if data.startswith(token, i):
                    encoded_data.append(self.stoi[token])
                    i += token_len
                    covered_chars += token_len
                    pbar.update(token_len)
                    matched = True
                    break
            if not matched:
                i += 1  # Skip character if no token matches
                pbar.update(1)
        pbar.close()
        coverage = covered_chars / data_len
        print(f"Data coverage by tokens: {coverage * 100:.2f}%")
        meta = {"vocab_size": len(self.tokens), "stoi": self.stoi, "itos": self.itos}
        self.save_meta(meta)
        return encoded_data

    def detokenize(self, ids):
        return ''.join([self.itos[id] for id in ids])


class CharTokenizer(Tokenizer):
    def __init__(self, args, train_data, val_data):
        super().__init__(args)
        self.reuse_chars = args.reuse_chars
        if self.reuse_chars:
            self.chars = self.get_key_from_meta('chars')
            if self.chars is None:
                raise ValueError("No chars found in meta.pkl. Cannot reuse chars.")
        else:
            self.chars = sorted(list(set(train_data + (val_data if val_data else ""))))
            print(f"All unique characters: {''.join(self.chars)}")
            print(f"Vocab size: {len(self.chars)}")
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

    def tokenize(self, data):
        data_len = len(data)
        ids = []
        pbar = tqdm(total=data_len, desc="Tokenizing Characters")
        for ch in data:
            ids.append(self.stoi[ch])
            pbar.update(1)
        pbar.close()
        meta = {"vocab_size": len(self.chars), "itos": self.itos, "stoi": self.stoi, "chars": self.chars}
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        return ''.join([self.itos[id] for id in ids])


class CustomCharTokenizerWithByteFallback(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        if args.custom_chars_file is None:
            raise ValueError("Custom characters file must be provided for this tokenizer.")
        with open(args.custom_chars_file, "r", encoding="utf-8") as f:
            self.custom_chars = [line.strip() for line in f if line.strip()]

        # Build vocab
        self.build_vocab()

    def build_vocab(self):
        # Assign IDs to custom characters
        self.stoi = {ch: i for i, ch in enumerate(self.custom_chars)}
        self.itos = {i: ch for i, ch in enumerate(self.custom_chars)}
        self.custom_char_count = len(self.custom_chars)

        # Assign IDs to bytes (0-255)
        self.byte_stoi = {byte: i + self.custom_char_count for i, byte in enumerate(range(256))}
        self.byte_itos = {i + self.custom_char_count: byte for i, byte in enumerate(range(256))}

        # Update total vocab size
        self.vocab_size = self.custom_char_count + 256  # 256 bytes

        # Merge the dictionaries for easy lookup
        self.stoi.update(self.byte_stoi)
        self.itos.update(self.byte_itos)

        # Save meta information
        meta = {
            "vocab_size": self.vocab_size,
            "tokenizer": "custom_char_with_byte_fallback",
            "custom_chars": self.custom_chars,
            "stoi": self.stoi,
            "itos": self.itos,
            "custom_char_count": self.custom_char_count,
        }
        self.save_meta(meta)

    def tokenize(self, data):
        ids = []
        data_len = len(data)
        pbar = tqdm(total=data_len, desc="Tokenizing with Byte Fallback")
        for ch in data:
            if ch in self.stoi:
                ids.append(self.stoi[ch])
            else:
                # Byte fallback
                byte_sequence = ch.encode('utf-8')
                for byte in byte_sequence:
                    ids.append(self.stoi[byte])
            pbar.update(1)
        pbar.close()
        return ids

    def detokenize(self, ids):
        chars = []
        byte_buffer = []
        for pos, id in enumerate(ids):
            if id < self.custom_char_count:
                # It's a custom character
                chars.append(self.itos[id])
            else:
                # It's a byte
                byte_buffer.append(self.itos[id])
                # Flush the buffer when this is the last token or the next token is not a byte
                if pos == len(ids) - 1 or ids[pos + 1] < self.custom_char_count:
                    byte_array = bytes(byte_buffer)
                    chars.append(byte_array.decode('utf-8', errors='replace'))
                    byte_buffer = []
        return ''.join(chars)


class GemmaTokenizer(Tokenizer):
    def __init__(self, args):
        """
        Initialize the GemmaTokenizer using Hugging Face's AutoTokenizer.
        """
        super().__init__(args)
        self.huggingface_model_name = f"google/{args.gemma_model}"
        self.tokenizer = AutoTokenizer.from_pretrained(self.huggingface_model_name)

        # Save vocab size and other meta information
        self.vocab_size = self.tokenizer.vocab_size
        self.special_tokens = self.tokenizer.special_tokens_map

    def tokenize(self, data):
        print(f"Tokenizing data of size: {len(data)}")
        chunk_size = 1024
        ids = []
        for i in range(0, len(data), chunk_size):
            chunk = data[i:i + chunk_size]
            ids.extend(self.tokenizer.encode(chunk, add_special_tokens=True))
        print(f"Generated {len(ids)} token IDs.")
        meta = {
            "vocab_size": self.vocab_size,
            "tokenizer": "gemma",
            "gemma_model": self.huggingface_model_name,
            "special_tokens": self.special_tokens,
        }
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        """
        Detokenize token IDs into a string.

        Args:
            ids (List[int]): List of token IDs to convert back to text.

        Returns:
            str: Decoded string.
        """
        return self.tokenizer.decode(ids, skip_special_tokens=True)
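

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how one of the tokenizers above might be driven, assuming the
# args object is a simple argparse-style namespace carrying the attributes each class
# reads (here only `tiktoken_encoding`). The input path "input.txt", the "gpt2"
# encoding, and the uint16 dump to "train.bin" are assumptions for this sketch only.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(tiktoken_encoding="gpt2")
    tokenizer = TiktokenTokenizer(args)

    with open("input.txt", "r", encoding="utf-8") as f:
        text = f.read()

    ids = tokenizer.tokenize(text)  # also writes meta.pkl via save_meta
    np.array(ids, dtype=np.uint16).tofile("train.bin")  # gpt2 ids fit in uint16
    print(tokenizer.detokenize(ids[:20]))  # quick round-trip check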