
Commit efafe03

added tokenizer file
1 parent 153a8df commit efafe03

File tree

1 file changed: +353 -0 lines changed

data/template/tokenizer_options.py

import os
import pickle
import tempfile
import numpy as np
import sentencepiece as spm
import tiktoken
from tqdm import tqdm  # For progress bars
from transformers import AutoTokenizer

class Tokenizer:
    def __init__(self, args):
        self.args = args

    def tokenize(self, data):
        raise NotImplementedError("Tokenize method must be implemented by subclasses.")

    def detokenize(self, ids):
        raise NotImplementedError("Detokenize method must be implemented by subclasses.")

    def save_meta(self, meta):
        with open("meta.pkl", "wb") as f:
            pickle.dump(meta, f)

    @staticmethod
    def get_key_from_meta(keyname):
        meta_path = 'meta.pkl'
        if os.path.exists(meta_path):
            with open(meta_path, 'rb') as f:
                meta = pickle.load(f)
            return meta.get(keyname)
        return None


class NumericRangeTokenizer(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        self.min_token = args.min_token
        self.max_token = args.max_token
        self.stoi = None
        self.itos = None

    def tokenize(self, data):
        tokens = []
        encountered_tokens = set()
        lines = data.strip().split('\n')
        for line in tqdm(lines, desc="Tokenizing Numeric Range"):
            try:
                num = int(line)
                if self.min_token <= num <= self.max_token:
                    tokens.append(num)
                    encountered_tokens.add(num)
                else:
                    print(f"Warning: Number {num} is outside the specified range and will be skipped.")
            except ValueError:
                print(f"Warning: Invalid number '{line}' will be skipped.")

        all_tokens = list(range(self.max_token, -1, -1))
        self.stoi = {str(num): i for i, num in enumerate(all_tokens)}
        self.itos = {i: str(num) for i, num in enumerate(all_tokens)}

        indexed_tokens = [self.stoi[str(token)] for token in tokens]
        meta = {
            "vocab_size": len(self.stoi),
            "tokenizer": "numeric_range",
            "min_token": self.min_token,
            "max_token": self.max_token,
            "stoi": self.stoi,
            "itos": self.itos,
            "encountered_tokens": sorted(encountered_tokens, reverse=True)
        }
        self.save_meta(meta)
        return indexed_tokens

    def detokenize(self, ids):
        return '\n'.join([self.itos[id] for id in ids])


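# Minimal usage sketch for NumericRangeTokenizer, assuming the args object only
# needs min_token/max_token here (argparse.Namespace is a stand-in for the real
# argument parser). Ids index into the descending list max_token..0, so with
# max_token=4 that list is [4, 3, 2, 1, 0] and the line "3" maps to id 1.
def _demo_numeric_range():
    from argparse import Namespace

    tok = NumericRangeTokenizer(Namespace(min_token=0, max_token=4))
    ids = tok.tokenize("3\n0\n4")            # also writes meta.pkl
    assert ids == [1, 4, 0]
    assert tok.detokenize(ids) == "3\n0\n4"

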
class SentencePieceTokenizer(Tokenizer):
    def __init__(self, args, input_files=None):
        super().__init__(args)
        self.vocab_size = args.vocab_size
        self.spm_model_file = args.spm_model_file
        self.spm_vocab_file = args.spm_vocab_file
        self.skip_tokenization = args.skip_tokenization
        self.input_files = input_files
        self.sp = None

        if self.spm_model_file:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.spm_model_file)
        elif input_files:
            self.sp = self.train_sentencepiece_model()

    def train_sentencepiece_model(self):
        spm_model_prefix = "trained_spm_model"
        num_threads = os.cpu_count()
        input_arg = ""
        if isinstance(self.input_files, list):
            with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile:
                for input_file in self.input_files:
                    with open(input_file, "r") as infile:
                        tmpfile.write(infile.read())
                input_arg = tmpfile.name
        else:
            input_arg = self.input_files

        spm.SentencePieceTrainer.train(
            num_threads=num_threads,
            user_defined_symbols="\n, ",
            input=input_arg,
            model_prefix=spm_model_prefix,
            split_digits=True,
            vocab_size=self.vocab_size,
            model_type="bpe",
        )
        print("SentencePiece model training complete.")

        if isinstance(self.input_files, list):
            os.remove(input_arg)

        sp = spm.SentencePieceProcessor()
        sp.load(f"{spm_model_prefix}.model")
        return sp

    def tokenize(self, data):
        if not self.sp:
            raise ValueError("SentencePiece model is not loaded.")
        ids = self.sp.encode_as_ids(data)
        stoi = {self.sp.id_to_piece(id): id for id in range(self.sp.GetPieceSize())}
        itos = {id: self.sp.id_to_piece(id) for id in range(self.sp.GetPieceSize())}

        meta = {
            "vocab_size": self.sp.GetPieceSize(),
            "tokenizer": "sentencepiece",
            "stoi": stoi,
            "itos": itos,
        }
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        if not self.sp:
            raise ValueError("SentencePiece model is not loaded.")
        return self.sp.decode_ids(ids)


class TiktokenTokenizer(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        self.tiktoken_encoding = args.tiktoken_encoding
        self.enc = tiktoken.get_encoding(self.tiktoken_encoding)
        self.vocab_size = self.enc.n_vocab

    def tokenize(self, data):
        ids = self.enc.encode_ordinary(data)
        meta = {
            "vocab_size": self.vocab_size,
            "tokenizer": "tiktoken",
            "tiktoken_encoding": self.tiktoken_encoding,
        }
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        return self.enc.decode(ids)


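# Minimal usage sketch for TiktokenTokenizer, assuming the args object only
# needs a tiktoken_encoding attribute here (argparse.Namespace is a stand-in).
# encode_ordinary ignores special tokens, so plain text round-trips exactly.
def _demo_tiktoken_roundtrip():
    from argparse import Namespace

    tok = TiktokenTokenizer(Namespace(tiktoken_encoding="gpt2"))
    ids = tok.tokenize("hello world")        # also writes meta.pkl
    assert tok.detokenize(ids) == "hello world"

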
class CustomTokenizer(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        if args.tokens_file is None:
            raise ValueError("Tokens file must be provided for custom tokenization method.")
        with open(args.tokens_file, "r") as f:
            self.tokens = [line.strip() for line in f.readlines() if line.strip()]
        self.tokens = [token.replace("\\n", "\n").replace("\\t", "\t") for token in self.tokens]
        self.stoi = {token: i for i, token in enumerate(self.tokens)}
        self.itos = {i: token for i, token in enumerate(self.tokens)}

    def tokenize(self, data):
        encoded_data = []
        i = 0
        covered_chars = 0
        data_len = len(data)
        pbar = tqdm(total=data_len, desc="Tokenizing Custom Tokens")
        while i < data_len:
            matched = False
            for token in self.tokens:
                token_len = len(token)
                if data.startswith(token, i):
                    encoded_data.append(self.stoi[token])
                    i += token_len
                    covered_chars += token_len
                    pbar.update(token_len)
                    matched = True
                    break
            if not matched:
                i += 1  # Skip character if no token matches
                pbar.update(1)
        pbar.close()
        coverage = covered_chars / data_len
        print(f"Data coverage by tokens: {coverage*100:.2f}%")
        meta = {"vocab_size": len(self.tokens), "stoi": self.stoi, "itos": self.itos}
        self.save_meta(meta)
        return encoded_data

    def detokenize(self, ids):
        return ''.join([self.itos[id] for id in ids])


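# Minimal usage sketch for CustomTokenizer with a throwaway tokens file
# (hypothetical contents, one token per line). Matching is greedy in file
# order, and any character not covered by a token is silently dropped.
def _demo_custom_tokens():
    from argparse import Namespace

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("hello\nworld\n")
        tokens_path = f.name
    tok = CustomTokenizer(Namespace(tokens_file=tokens_path))
    ids = tok.tokenize("helloworld")
    assert tok.detokenize(ids) == "helloworld"
    os.remove(tokens_path)

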
class CharTokenizer(Tokenizer):
    def __init__(self, args, train_data, val_data):
        super().__init__(args)
        self.reuse_chars = args.reuse_chars
        if self.reuse_chars:
            self.chars = self.get_key_from_meta('chars')
            if self.chars is None:
                raise ValueError("No chars found in meta.pkl. Cannot reuse chars.")
        else:
            self.chars = sorted(list(set(train_data + (val_data if val_data else ""))))
            print(f"All unique characters: {''.join(self.chars)}")
            print(f"Vocab size: {len(self.chars)}")
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

    def tokenize(self, data):
        data_len = len(data)
        ids = []
        pbar = tqdm(total=data_len, desc="Tokenizing Characters")
        for ch in data:
            ids.append(self.stoi[ch])
            pbar.update(1)
        pbar.close()
        meta = {"vocab_size": len(self.chars), "itos": self.itos, "stoi": self.stoi, "chars": self.chars}
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        return ''.join([self.itos[id] for id in ids])


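# Minimal usage sketch for CharTokenizer, assuming reuse_chars=False so the
# vocabulary is built from the text itself; val_data may be None here.
def _demo_char_roundtrip():
    from argparse import Namespace

    text = "hello world"
    tok = CharTokenizer(Namespace(reuse_chars=False), text, None)
    ids = tok.tokenize(text)                 # also writes meta.pkl
    assert tok.detokenize(ids) == text

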
class CustomCharTokenizerWithByteFallback(Tokenizer):
    def __init__(self, args):
        super().__init__(args)
        if args.custom_chars_file is None:
            raise ValueError("Custom characters file must be provided for this tokenizer.")
        with open(args.custom_chars_file, "r", encoding="utf-8") as f:
            self.custom_chars = [line.strip() for line in f if line.strip()]

        # Build vocab
        self.build_vocab()

    def build_vocab(self):
        # Assign IDs to custom characters
        self.stoi = {ch: i for i, ch in enumerate(self.custom_chars)}
        self.itos = {i: ch for i, ch in enumerate(self.custom_chars)}
        self.custom_char_count = len(self.custom_chars)

        # Assign IDs to bytes (0-255)
        self.byte_stoi = {byte: i + self.custom_char_count for i, byte in enumerate(range(256))}
        self.byte_itos = {i + self.custom_char_count: byte for i, byte in enumerate(range(256))}

        # Update total vocab size
        self.vocab_size = self.custom_char_count + 256  # 256 bytes

        # Merge the dictionaries for easy lookup
        self.stoi.update(self.byte_stoi)
        self.itos.update(self.byte_itos)

        # Save meta information
        meta = {
            "vocab_size": self.vocab_size,
            "tokenizer": "custom_char_with_byte_fallback",
            "custom_chars": self.custom_chars,
            "stoi": self.stoi,
            "itos": self.itos,
            "custom_char_count": self.custom_char_count,
        }
        self.save_meta(meta)

    def tokenize(self, data):
        ids = []
        data_len = len(data)
        pbar = tqdm(total=data_len, desc="Tokenizing with Byte Fallback")
        for ch in data:
            if ch in self.stoi:
                ids.append(self.stoi[ch])
            else:
                # Byte fallback
                byte_sequence = ch.encode('utf-8')
                for byte in byte_sequence:
                    ids.append(self.stoi[byte])
            pbar.update(1)
        pbar.close()
        return ids

    def detokenize(self, ids):
        chars = []
        byte_buffer = []
        for pos, id in enumerate(ids):
            if id < self.custom_char_count:
                # It's a custom character
                chars.append(self.itos[id])
            else:
                # It's a byte
                byte_buffer.append(self.itos[id])
                # Flush the buffered bytes once the run of byte tokens ends:
                # either this is the last token or the next one is a custom character
                if pos + 1 == len(ids) or ids[pos + 1] < self.custom_char_count:
                    # Convert byte buffer to character(s)
                    byte_array = bytes(byte_buffer)
                    chars.append(byte_array.decode('utf-8', errors='replace'))
                    byte_buffer = []
        return ''.join(chars)


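# Minimal usage sketch for CustomCharTokenizerWithByteFallback with a throwaway
# character file (hypothetical contents: a, b, c, one per line). Characters
# outside that set, such as the space and 'é' below, are encoded as raw UTF-8
# bytes and reassembled on detokenize.
def _demo_byte_fallback_roundtrip():
    from argparse import Namespace

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
        f.write("a\nb\nc\n")
        chars_path = f.name
    tok = CustomCharTokenizerWithByteFallback(Namespace(custom_chars_file=chars_path))
    text = "abc é"
    ids = tok.tokenize(text)
    assert tok.detokenize(ids) == text
    os.remove(chars_path)

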
class GemmaTokenizer(Tokenizer):
    def __init__(self, args):
        """
        Initialize the GemmaTokenizer using Hugging Face's AutoTokenizer.
        """
        super().__init__(args)
        self.huggingface_model_name = f"google/{args.gemma_model}"
        self.tokenizer = AutoTokenizer.from_pretrained(self.huggingface_model_name)

        # Save vocab size and other meta information
        self.vocab_size = self.tokenizer.vocab_size
        self.special_tokens = self.tokenizer.special_tokens_map

    def tokenize(self, data):
        print(f"Tokenizing data of size: {len(data)}")
        chunk_size = 1024
        ids = []
        for i in range(0, len(data), chunk_size):
            chunk = data[i:i + chunk_size]
            ids.extend(self.tokenizer.encode(chunk, add_special_tokens=True))
        print(f"Generated {len(ids)} token IDs.")
        meta = {
            "vocab_size": self.vocab_size,
            "tokenizer": "gemma",
            "gemma_model": self.huggingface_model_name,
            "special_tokens": self.special_tokens,
        }
        self.save_meta(meta)
        return ids

    def detokenize(self, ids):
        """
        Detokenize token IDs into a string.
        Args:
            ids (List[int]): List of token IDs to convert back to text.
        Returns:
            str: Decoded string.
        """
        return self.tokenizer.decode(ids, skip_special_tokens=True)
