diff --git a/novelai_api/Tokenizer.py b/novelai_api/Tokenizer.py
index 43b3bba..314eb39 100644
--- a/novelai_api/Tokenizer.py
+++ b/novelai_api/Tokenizer.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import List, Union
 
+import sentencepiece
 import tokenizers
 
 from novelai_api.ImagePreset import ImageModel
@@ -53,11 +54,23 @@ def get_tokenizer_name(cls, model: Model) -> str:
     # TODO: check differences from NAI tokenizer (from my limited testing, there is None)
     _CLIP_TOKENIZER = SimpleTokenizer()
 
+    _NERDSTASH_TOKENIZER_v1 = sentencepiece.SentencePieceProcessor()
+    _NERDSTASH_TOKENIZER_v1.Load(str(tokenizers_path / "nerdstash_v1.model"))
+    _NERDSTASH_TOKENIZER_v1.encode = _NERDSTASH_TOKENIZER_v1.EncodeAsIds
+    _NERDSTASH_TOKENIZER_v1.decode = _NERDSTASH_TOKENIZER_v1.DecodeIds
+
+    _NERDSTASH_TOKENIZER_v2 = sentencepiece.SentencePieceProcessor()
+    _NERDSTASH_TOKENIZER_v2.Load(str(tokenizers_path / "nerdstash_v2.model"))
+    _NERDSTASH_TOKENIZER_v2.encode = _NERDSTASH_TOKENIZER_v2.EncodeAsIds
+    _NERDSTASH_TOKENIZER_v2.decode = _NERDSTASH_TOKENIZER_v2.DecodeIds
+
     _tokenizers = {
         "gpt2": _GPT2_TOKENIZER,
         "gpt2-genji": _GENJI_TOKENIZER,
         "pile": _PILE_TOKENIZER,
         "clip": _CLIP_TOKENIZER,
+        "nerdstash_v1": _NERDSTASH_TOKENIZER_v1,
+        "nerdstash_v2": _NERDSTASH_TOKENIZER_v2,
     }
 
     @classmethod
@@ -93,7 +106,7 @@ def encode(cls, model: AnyModel, o: str) -> List[int]:
         if isinstance(tokenizer, tokenizers.Tokenizer):
             return tokenizer.encode(o).ids
 
-        if isinstance(tokenizer, SimpleTokenizer):
+        if isinstance(tokenizer, (SimpleTokenizer, sentencepiece.SentencePieceProcessor)):
             return tokenizer.encode(o)
 
         raise ValueError(f"Tokenizer {tokenizer} ({tokenizer_name}) not recognized")
diff --git a/novelai_api/_high_level.py b/novelai_api/_high_level.py
index d0fa199..9641f9a 100644
--- a/novelai_api/_high_level.py
+++ b/novelai_api/_high_level.py
@@ -115,7 +115,7 @@ async def set_keystore(self, keystore: Keystore, key: bytes) -> bytes:
         return await self._parent.low_level.set_keystore(keystore.data)
 
-    async def download_user_stories(self) -> Dict[str, Dict[str, Union[str, int]]]:
+    async def download_user_stories(self) -> List[Dict[str, Dict[str, Union[str, int]]]]:
         """
         Download all the objects of type 'stories' stored on the account
         """
 
@@ -126,7 +126,7 @@ async def download_user_stories(self) -> Dict[str, Dict[str, Union[str, int]]]:
 
     async def download_user_story_contents(
         self,
-    ) -> Dict[str, Dict[str, Union[str, int]]]:
+    ) -> List[Dict[str, Dict[str, Union[str, int]]]]:
         """
         Download all the objects of type 'storycontent' stored on the account
         """
diff --git a/novelai_api/tokenizers/nerdstash_tokenizer.model b/novelai_api/tokenizers/nerdstash_v1.model
similarity index 100%
rename from novelai_api/tokenizers/nerdstash_tokenizer.model
rename to novelai_api/tokenizers/nerdstash_v1.model
diff --git a/novelai_api/tokenizers/nerdstash_v2.model b/novelai_api/tokenizers/nerdstash_v2.model
new file mode 100644
index 0000000..ec24531
Binary files /dev/null and b/novelai_api/tokenizers/nerdstash_v2.model differ