tokenizer.py
import os
from typing import Any

import torch
from llama.tokenizer import Tokenizer as LlamaTokenizer
from transformers import AutoTokenizer


class Tokenizer:
    """Thin wrapper that selects between the CodeLlama tokenizer and a
    Hugging Face AutoTokenizer, depending on the configured model."""

    def __init__(
        self,
        config,
    ):
        self.config = config
        if self.config.model == "codellama":
            # CodeLlama ships its own SentencePiece-based tokenizer,
            # loaded from a local tokenizer model file.
            self.tokenizer = LlamaTokenizer(
                os.path.expanduser(self.config.tokenizer_path)
            )
            self.pad_id = self.tokenizer.eos_id
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
            self.tokenizer.pad_token = "<fim_pad>"
            # Store the integer token id so pad_id is consistent with the
            # codellama branch, which also holds an id rather than a string.
            self.pad_id = self.tokenizer.pad_token_id

    def encode(self, data: Any) -> Any:
        if self.config.model == "codellama":
            # CodeLlama tokenizer: prepend BOS, omit EOS, return a list of ids.
            return self.tokenizer.encode(
                data,
                bos=True,
                eos=False,
            )
        else:
            # Hugging Face tokenizer: return a PyTorch tensor padded/truncated
            # to the configured maximum context length.
            return self.tokenizer.encode(
                data,
                return_tensors="pt",
                padding="max_length",
                max_length=self.config.max_context_len,
                truncation=True,
            )

    def decode(self, tokenized_data: Any) -> str:
        return self.tokenizer.decode(tokenized_data)
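

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original file: the config
    # object below is a stand-in built with SimpleNamespace; the real project
    # presumably passes its own config class exposing at least the attributes
    # read above (model, tokenizer_path, max_context_len). The model name and
    # paths are placeholders.
    from types import SimpleNamespace

    config = SimpleNamespace(
        model="bigcode/starcoderbase",  # placeholder Hugging Face model id
        tokenizer_path="~/models/codellama/tokenizer.model",  # used only for codellama
        max_context_len=512,
    )
    tokenizer = Tokenizer(config)
    ids = tokenizer.encode("def add(a, b):\n    return a + b")
    # The Hugging Face branch returns a batch of shape (1, max_context_len),
    # so decode the first (and only) sequence.
    print(tokenizer.decode(ids[0]))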