diff --git a/novelai_api/Tokenizer.py b/novelai_api/Tokenizer.py index f153e21..7905a33 100644 --- a/novelai_api/Tokenizer.py +++ b/novelai_api/Tokenizer.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import List, Union +import sentencepiece import tokenizers from novelai_api.ImagePreset import ImageModel @@ -47,11 +48,18 @@ def get_tokenizer_name(cls, model: Model) -> str: # TODO: check differences from NAI tokenizer (from my limited testing, there is None) _CLIP_TOKENIZER = SimpleTokenizer() + _NERDSTASH_PATH = tokenizers_path / "nerdstash_tokenizer.model" + _NERDSTASH_TOKENIZER = sentencepiece.SentencePieceProcessor() + _NERDSTASH_TOKENIZER.LoadFromFile(str(_NERDSTASH_PATH)) + _NERDSTASH_TOKENIZER.encode = _NERDSTASH_TOKENIZER.EncodeAsIds + _NERDSTASH_TOKENIZER.decode = _NERDSTASH_TOKENIZER.DecodeIds + _tokenizers = { "gpt2": _GPT2_TOKENIZER, "gpt2-genji": _GENJI_TOKENIZER, "pile": _PILE_TOKENIZER, "clip": _CLIP_TOKENIZER, + "nerdstash": _NERDSTASH_TOKENIZER, } @classmethod @@ -69,7 +77,7 @@ def encode(cls, model: AnyModel, o: str) -> List[int]: if isinstance(tokenizer, tokenizers.Tokenizer): return tokenizer.encode(o).ids - if isinstance(tokenizer, SimpleTokenizer): + if isinstance(tokenizer, (sentencepiece.SentencePieceProcessor, SimpleTokenizer)): return tokenizer.encode(o) raise ValueError(f"Tokenizer {tokenizer} ({tokenizer_name}) not recognized") diff --git a/novelai_api/tokenizers/nerdstash_tokenizer.model b/novelai_api/tokenizers/nerdstash_tokenizer.model new file mode 100644 index 0000000..b95958a Binary files /dev/null and b/novelai_api/tokenizers/nerdstash_tokenizer.model differ diff --git a/pyproject.toml b/pyproject.toml index de2d557..f907648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ jsonschema = "^4.17.0" tokenizers = "^0.13.1" ftfy = "^6.1.1" regex = "^2022.10.31" +sentencepiece = "^0.1.98" [tool.poetry.group.dev.dependencies] pytest-asyncio = "^0.20.1"