Skip to content

Commit

Permalink
Merge branch 'nerdstash'
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed Apr 20, 2023
2 parents 058abe0 + f811f30 commit ed1cecd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 1 deletion.
10 changes: 9 additions & 1 deletion novelai_api/Tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import List, Union

import sentencepiece
import tokenizers

from novelai_api.ImagePreset import ImageModel
Expand Down Expand Up @@ -47,11 +48,18 @@ def get_tokenizer_name(cls, model: Model) -> str:
# TODO: check differences from NAI tokenizer (from my limited testing, there are none)
_CLIP_TOKENIZER = SimpleTokenizer()

# SentencePiece model file for the "nerdstash" tokenizer, stored alongside the other tokenizer files.
_NERDSTASH_PATH = tokenizers_path / "nerdstash_tokenizer.model"
# The processor is created empty and the model is loaded from disk when this code runs.
_NERDSTASH_TOKENIZER = sentencepiece.SentencePieceProcessor()
# The path is converted to str before being handed to LoadFromFile.
_NERDSTASH_TOKENIZER.LoadFromFile(str(_NERDSTASH_PATH))
# Bind `encode` / `decode` on this instance to the id-based SentencePiece methods, so that
# `tokenizer.encode(o)` yields token ids, as the shared encode path expects.
# NOTE(review): `decode` is presumably consumed by a matching decode path not shown here - confirm.
_NERDSTASH_TOKENIZER.encode = _NERDSTASH_TOKENIZER.EncodeAsIds
_NERDSTASH_TOKENIZER.decode = _NERDSTASH_TOKENIZER.DecodeIds

# Registry mapping a tokenizer name to its loaded tokenizer instance.
# NOTE(review): keys presumably match the names returned by get_tokenizer_name - confirm.
_tokenizers = {
"gpt2": _GPT2_TOKENIZER,
"gpt2-genji": _GENJI_TOKENIZER,
"pile": _PILE_TOKENIZER,
"clip": _CLIP_TOKENIZER,
"nerdstash": _NERDSTASH_TOKENIZER,
}

@classmethod
Expand All @@ -69,7 +77,7 @@ def encode(cls, model: AnyModel, o: str) -> List[int]:
if isinstance(tokenizer, tokenizers.Tokenizer):
return tokenizer.encode(o).ids

if isinstance(tokenizer, SimpleTokenizer):
if isinstance(tokenizer, (sentencepiece.SentencePieceProcessor, SimpleTokenizer)):
return tokenizer.encode(o)

raise ValueError(f"Tokenizer {tokenizer} ({tokenizer_name}) not recognized")
Binary file added novelai_api/tokenizers/nerdstash_tokenizer.model
Binary file not shown.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jsonschema = "^4.17.0"
tokenizers = "^0.13.1"
ftfy = "^6.1.1"
regex = "^2022.10.31"
sentencepiece = "^0.1.98"

[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.20.1"
Expand Down

0 comments on commit ed1cecd

Please sign in to comment.