Skip to content

Commit

Permalink
[API] Add document unpacking for story content
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed Jun 1, 2024
1 parent 558ab34 commit 9ab194a
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 10 deletions.
6 changes: 6 additions & 0 deletions example/boilerplate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
from datetime import datetime
from logging import Logger, StreamHandler
from os import environ as env
from pathlib import Path
from typing import Any, Optional

from aiohttp import ClientSession
from msgpackr.constants import UNDEFINED

from novelai_api import NovelAIAPI
from novelai_api.utils import get_encryption_key
Expand Down Expand Up @@ -85,6 +87,10 @@ class JSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
    """
    Serialize the values the stock JSON encoder does not know how to handle

    :param o: Object the base encoder could not serialize

    :return: Hex string for bytes, "<UNDEFINED>" for msgpackr's UNDEFINED constant,
        ISO formatted string for datetime, result of the base encoder for anything else
    """

    if isinstance(o, bytes):
        encoded = o.hex()
    elif o is UNDEFINED:
        # identity check, UNDEFINED is a constant imported from msgpackr
        encoded = "<UNDEFINED>"
    elif isinstance(o, datetime):
        encoded = o.isoformat()
    else:
        # unknown type, defer to json.JSONEncoder (which raises TypeError)
        encoded = super().default(o)

    return encoded

Expand Down
30 changes: 30 additions & 0 deletions example/download_last_story_and_decrypt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio
from pathlib import Path

from example.boilerplate import API, dumps
from novelai_api.utils import decrypt_user_data

# Output path of the dump: results/story.json, located next to the directory containing this script
dump_file = Path(__file__).parent.parent / "results" / "story.json"


async def main():
    """
    Download the first story returned for the account and decrypt it, then download the
    matching story content, decrypt it, unpack its document and dump the result as JSON
    into ``dump_file``

    If the account has no story, a message is printed and nothing is written
    """

    async with API() as api_handler:
        api = api_handler.api
        key = api_handler.encryption_key

        keystore = await api.high_level.get_keystore(key)

        stories = await api.high_level.download_user_stories()
        if not stories:
            # nothing to index into, bail out instead of crashing with a bare IndexError
            print("No story found on this account, nothing to download")
            return

        story = stories[0]
        decrypt_user_data(story, keystore)

        # the story metadata holds the id of the object containing the actual content
        storycontent_id = story["data"]["remoteStoryId"]

        story_contents = await api.low_level.download_object("storycontent", storycontent_id)
        # uncompress_document also unpacks the "document" field of the decrypted content
        decrypt_user_data(story_contents, keystore, uncompress_document=True)

        # parents=True so the dump doesn't fail if an intermediate directory is missing
        dump_file.parent.mkdir(parents=True, exist_ok=True)
        dump_file.write_text(dumps(story_contents), "utf-8")


if __name__ == "__main__":
    asyncio.run(main())
8 changes: 0 additions & 8 deletions novelai_api/GlobalSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,14 +712,6 @@ class GlobalSettings:
".",
":",
"\n",
"ve",
"s",
"t",
"n",
"d",
"ll",
"re",
"m",
"-",
"*",
")",
Expand Down
40 changes: 40 additions & 0 deletions novelai_api/Msgpackr_Extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any

from msgpackr.constants import SKIP
from msgpackr.unpack import MsgpackExtension


class NAIExtension(MsgpackExtension):
    """
    Common base of the msgpack extension types declared in this module

    The extension payload is not decoded: unpack always answers with msgpackr's SKIP constant
    (NOTE(review): presumably this makes the unpacker drop the value - confirm against msgpackr)
    """

    @classmethod
    def unpack(cls, _unpacker, _data: bytes, _pos: int, _length: int) -> Any:
        """
        Ignore the extension payload

        :param _unpacker: Unused
        :param _data: Unused
        :param _pos: Unused
        :param _length: Unused

        :return: Always the SKIP constant
        """

        return SKIP

    # TODO: the data should be bundled in post_unpack

    @classmethod
    def pack(cls, _unpacker, data: Any) -> bytes:
        """
        Hand the data back untouched

        :param _unpacker: Unused
        :param data: Value to pack

        :return: The data argument, as received
        """

        return data


# One subclass per extension type code (20, 30, 31, 40, 41 and 42)
# They only differ by EXT_TYPE, the unpack and pack behaviors are inherited from NAIExtension


class Ext20(NAIExtension):
    EXT_TYPE = 20


class Ext30(NAIExtension):
    EXT_TYPE = 30


class Ext31(NAIExtension):
    EXT_TYPE = 31


class Ext40(NAIExtension):
    EXT_TYPE = 40


class Ext41(NAIExtension):
    EXT_TYPE = 41


class Ext42(NAIExtension):
    EXT_TYPE = 42
16 changes: 15 additions & 1 deletion novelai_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
from zlib import decompress as inflate

import argon2
from msgpackr import Unpacker
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox

from novelai_api.Keystore import Keystore
from novelai_api.Msgpackr_Extensions import Ext20, Ext30, Ext31, Ext40, Ext41, Ext42
from novelai_api.NovelAIError import NovelAIError
from novelai_api.Preset import Model, Preset
from novelai_api.python_utils import assert_type
from novelai_api.Tokenizer import Tokenizer

# Module-wide msgpackr unpacker, with the extension types from Msgpackr_Extensions registered
unpacker = Unpacker()
unpacker.register_extensions(Ext20, Ext30, Ext31, Ext40, Ext41, Ext42)
# State exported right after setup, decrypt_user_data restores it before unpacking a document
unpacker_state = unpacker.export_state()


# API utils
def argon_hash(email: str, password: str, size: int, domain: str) -> str:
Expand Down Expand Up @@ -185,13 +191,16 @@ def compress_user_data(items: Union[List[Dict[str, Any]], Dict[str, Any]]):
del item["decrypted"]


def decrypt_user_data(items: Union[List[Dict[str, Any]], Dict[str, Any]], keystore: Keystore):
def decrypt_user_data(
items: Union[List[Dict[str, Any]], Dict[str, Any]], keystore: Keystore, uncompress_document: bool = False
):
"""
Decrypt the data of each item in :ref: items
If an item has already been decrypted, it won't be decrypted a second time
:param items: Item or list of items to decrypt
:param keystore: Keystore retrieved with the get_keystore method
:param uncompress_document: If True, the document will be decompressed
"""

# 1 item
Expand Down Expand Up @@ -224,6 +233,11 @@ def decrypt_user_data(items: Union[List[Dict[str, Any]], Dict[str, Any]], keysto
item["nonce"] = nonce
item["decrypted"] = True
item["compressed"] = is_compressed

if uncompress_document and "document" in data:
unpacker.restore_state(unpacker_state)
data["document"] = unpacker.unpack(b64decode(data["document"]))

continue

except json.JSONDecodeError:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ftfy = "^6.1.3"
regex = "^2023.12.25"
sentencepiece = "^0.2.0"
poetry = "^1.7.1"
# msgpack = "^1.0.5"
msgpackr-python = "^0.1.2"

[tool.poetry.group.dev.dependencies]
python-dotenv = "^0.21.1"
Expand Down

0 comments on commit 9ab194a

Please sign in to comment.