Skip to content

Commit

Permalink
[API] Add NerdStash tokenizer v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed May 20, 2023
1 parent 99a571c commit 38dbf93
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
15 changes: 14 additions & 1 deletion novelai_api/Tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import List, Union

import sentencepiece
import tokenizers

from novelai_api.ImagePreset import ImageModel
Expand Down Expand Up @@ -53,11 +54,23 @@ def get_tokenizer_name(cls, model: Model) -> str:
# TODO: check differences from NAI tokenizer (from my limited testing, there are none)
_CLIP_TOKENIZER = SimpleTokenizer()

_NERDSTASH_TOKENIZER_v1 = sentencepiece.SentencePieceProcessor()
_NERDSTASH_TOKENIZER_v1.Load(str(tokenizers_path / "nerdstash_v1.model"))
_NERDSTASH_TOKENIZER_v1.encode = _NERDSTASH_TOKENIZER_v1.EncodeAsIds
_NERDSTASH_TOKENIZER_v1.decode = _NERDSTASH_TOKENIZER_v1.DecodeIds

# NerdStash v2 tokenizer: must load its own SentencePiece model file
# (nerdstash_v2.model), not the v1 model, otherwise v2 is just a copy of v1.
_NERDSTASH_TOKENIZER_v2 = sentencepiece.SentencePieceProcessor()
_NERDSTASH_TOKENIZER_v2.Load(str(tokenizers_path / "nerdstash_v2.model"))
# Expose encode/decode aliases so callers can use tokenizer.encode(...) / tokenizer.decode(...)
_NERDSTASH_TOKENIZER_v2.encode = _NERDSTASH_TOKENIZER_v2.EncodeAsIds
_NERDSTASH_TOKENIZER_v2.decode = _NERDSTASH_TOKENIZER_v2.DecodeIds

_tokenizers = {
"gpt2": _GPT2_TOKENIZER,
"gpt2-genji": _GENJI_TOKENIZER,
"pile": _PILE_TOKENIZER,
"clip": _CLIP_TOKENIZER,
"nerdstash_v1": _NERDSTASH_TOKENIZER_v1,
"nerdstash_v2": _NERDSTASH_TOKENIZER_v2,
}

@classmethod
Expand Down Expand Up @@ -93,7 +106,7 @@ def encode(cls, model: AnyModel, o: str) -> List[int]:
if isinstance(tokenizer, tokenizers.Tokenizer):
return tokenizer.encode(o).ids

if isinstance(tokenizer, SimpleTokenizer):
if isinstance(tokenizer, (SimpleTokenizer, sentencepiece.SentencePieceProcessor)):
return tokenizer.encode(o)

raise ValueError(f"Tokenizer {tokenizer} ({tokenizer_name}) not recognized")
4 changes: 2 additions & 2 deletions novelai_api/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def set_keystore(self, keystore: Keystore, key: bytes) -> bytes:

return await self._parent.low_level.set_keystore(keystore.data)

async def download_user_stories(self) -> Dict[str, Dict[str, Union[str, int]]]:
async def download_user_stories(self) -> List[Dict[str, Dict[str, Union[str, int]]]]:
"""
Download all the objects of type 'stories' stored on the account
"""
Expand All @@ -126,7 +126,7 @@ async def download_user_stories(self) -> Dict[str, Dict[str, Union[str, int]]]:

async def download_user_story_contents(
self,
) -> Dict[str, Dict[str, Union[str, int]]]:
) -> List[Dict[str, Dict[str, Union[str, int]]]]:
"""
Download all the objects of type 'storycontent' stored on the account
"""
Expand Down
File renamed without changes.
Binary file added novelai_api/tokenizers/nerdstash_v2.model
Binary file not shown.

0 comments on commit 38dbf93

Please sign in to comment.