diff --git a/README.md b/README.md index 6073d27..21ed475 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,60 @@ -# novelai-api -Python API for the NovelAI REST API - -This module is intended to be used by developers as a helper for using NovelAI's REST API. - -[TODO]: # (Add Quality Checking workflows and badges) - -| Category | Badges | -|------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Pypi | [![PyPI](https://img.shields.io/pypi/v/novelai-api)](https://pypi.org/project/novelai-api) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/novelai-api)](https://pypi.org/project/novelai-api) [![PyPI - License](https://img.shields.io/pypi/l/novelai-api)](https://pypi.org/project/novelai-api/) [![PyPI - Format](https://img.shields.io/pypi/format/novelai-api)](https://pypi.org/project/novelai-api/) | -| Quality checking | [![Python package](https://github.com/Aedial/novelai-api/actions/workflows/python-package.yml/badge.svg)](https://github.com/Aedial/novelai-api/actions/workflows/python-package.yml) [![Python package](https://github.com/Aedial/novelai-api/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/Aedial/novelai-api/actions/workflows/codeql-analysis.yml) [![linting: pylint](https://img.shields.io/badge/linting-pylint-yellowgreen)](https://github.com/PyCQA/pylint) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) | -| Stats | [![GitHub top language](https://img.shields.io/github/languages/top/Aedial/novelai-api)](https://github.com/Aedial/novelai-api/search?l=python) ![Libraries.io dependency status for GitHub repo](https://img.shields.io/librariesio/github/Aedial/novelai-api) ![GitHub repo size](https://img.shields.io/github/repo-size/Aedial/novelai-api) ![GitHub issues](https://img.shields.io/github/issues-raw/Aedial/novelai-api) ![GitHub pull requests](https://img.shields.io/github/issues-pr-raw/Aedial/novelai-api) | -| Activity | ![GitHub last commit](https://img.shields.io/github/last-commit/Aedial/novelai-api) ![GitHub commits since tagged version](https://img.shields.io/github/commits-since/Aedial/novelai-api/v0.11.2) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/Aedial/novelai-api) | - - -### Prerequisites -Before anything, ensure that nox is installed (pip install nox). -For logging in, credentials are needed (NAI_USERNAME and NAI_PASSWORD). They should be passed via the environment variables (dotenv file supported). - -### Examples -The examples are in the example folder. Each example is standalone and can be used as a test. -Examples should be ran with `nox -s run -- python example/.py`. - -Some tests can act as example. The full list is as follows : -- decryption and re-encryption: tests/test_decrypt_encrypt_integrity_check.py -- diverse generations: tests/test_generate.py -- parallel generations: tests/test_generate_parallel.py - -### Usage -The source and all the required functions are located in the novelai-api folder. -The examples and tests showcase how this API should be used and can be regarded as the "right way" to use it. However, it doesn't mean one can't use the "low level" part, which is a thin implementation of the REST endpoints, while the "high level" part is an abstraction built on that low level. - -### Contributing -You can contribute features and enhancements through PR. Any PR should pass the tests and the pre-commits before submission. - -The tests against the API can be ran with `nox -s test_api`. Note that having node.js installed is required for the test to run properly. -/!\ WIP /!\ The tests against the mocked backend can be ran with `nox -s test_mock`. - -To install and run the pre-commit hook, run `nox -s pre-commit`. This hook should be installed before committing anything. +# novelai-api +Python API for the NovelAI REST API + +This module is intended to be used by developers as a helper for using NovelAI's REST API. + +[TODO]: # (Add Quality Checking workflows and badges) + +| Category | Badges | +|------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Pypi | [![PyPI](https://img.shields.io/pypi/v/novelai-api)](https://pypi.org/project/novelai-api) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/novelai-api)](https://pypi.org/project/novelai-api) [![PyPI - License](https://img.shields.io/pypi/l/novelai-api)](https://pypi.org/project/novelai-api/) [![PyPI - Format](https://img.shields.io/pypi/format/novelai-api)](https://pypi.org/project/novelai-api/) | +| Quality checking | [![Python package](https://github.com/Aedial/novelai-api/actions/workflows/python-package.yml/badge.svg)](https://github.com/Aedial/novelai-api/actions/workflows/python-package.yml) [![Python package](https://github.com/Aedial/novelai-api/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/Aedial/novelai-api/actions/workflows/codeql-analysis.yml) [![linting: pylint](https://img.shields.io/badge/linting-pylint-yellowgreen)](https://github.com/PyCQA/pylint) [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) | +| Stats | [![GitHub top language](https://img.shields.io/github/languages/top/Aedial/novelai-api)](https://github.com/Aedial/novelai-api/search?l=python) ![Libraries.io dependency status for GitHub repo](https://img.shields.io/librariesio/github/Aedial/novelai-api) ![GitHub repo size](https://img.shields.io/github/repo-size/Aedial/novelai-api) ![GitHub issues](https://img.shields.io/github/issues-raw/Aedial/novelai-api) ![GitHub pull requests](https://img.shields.io/github/issues-pr-raw/Aedial/novelai-api) | +| Activity | ![GitHub last commit](https://img.shields.io/github/last-commit/Aedial/novelai-api) ![GitHub commits since tagged version](https://img.shields.io/github/commits-since/Aedial/novelai-api/v0.10.5) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/Aedial/novelai-api) | + + +# Usage +Download via [pip](https://pypi.org/project/novelai-api): +``` +pip install novelai-api +``` + +A full list of examples is available in the [example](/example) directory + +The API works through the NovelAIAPI object. +It is split in 2 groups: NovelAIAPI.low_level and NovelAIAPI.high_level + +## low_level +The low level interface is a strict implementation of the official API (). +It only checks for input types via assert and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True + +## high_level +The high level interface builds on the low level one for easier handling of complex settings. +It handles many tasks from the frontend + + +# Development +All relevant objects are in the [novelai_api](novelai_api) directory. The [nox](https://pypi.org/project/nox/) package is required (`pip install nox`). + +## Contributing +You can contribute features and enhancements through PR. Any PR should pass the tests and the pre-commits before submission. +The pre-commit hook can be installed via +``` +nox -s pre-commit +``` + +## Testing against the API +[API](tests/api) + +## Testing against the mocked API +| :warning: WIP, does not work yet :warning: | +|--------------------------------------------| + +[Mock](tests/mock) + +## Docs +To build the docs, run +``` +nox -s build-docs +``` +The docs will be locally viewable at docs/build/html/index.html diff --git a/docs/requirements.txt b/docs/requirements.txt index 516006e..f1b5af7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,7 @@ sphinx==6.1.3 -myst-parser +# patched repo to work with relative links +git+https://github.com/Aedial/MyST-Parser +linkify-it-py sphinx-copybutton sphinx_last_updated_by_git sphinx-hoverxref diff --git a/docs/source/conf.py b/docs/source/conf.py index a9f612f..0a8c4ff 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3,15 +3,22 @@ # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html +import datetime +import inspect import os import sys +from pathlib import Path +from types import ModuleType +from typing import List + +from sphinx.application import Sphinx +from sphinx.ext.autodoc import Options # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "NovelAI API" -# pylint: disable=W0622 -copyright = "2023, Aedial" # noqa (built-in) +copyright = f"{datetime.datetime.now().year}, Aedial" # noqa (built-in), pylint: disable=W0622 author = "Aedial" release = "0.11.6" @@ -33,9 +40,16 @@ "hoverxref.extension", ] +autodoc_class_signature = "separated" autodoc_member_order = "bysource" +autodoc_typehints_format = "fully-qualified" +autodoc_preserve_defaults = True +autodoc_inherit_docstrings = False + +extlinks = {"issue": ("https://github.com/Aedial/novelai-api/issues/%s", "[issue %s]")} -extlinks = {"issue": ("https://github.com/sphinx-doc/sphinx/issues/%s", "[issue %s]")} +myst_all_links_external = True +myst_relative_links_base = "https://github.com/Aedial/novelai-api/tree/main/" suppress_warnings = ["myst.header", "git.too_shallow"] @@ -62,3 +76,28 @@ html_theme = "classic" # no asset yet # html_static_path = ['_static'] + + +# -- Hooks ------------------------------------------------------------------- + + +def format_docstring(_app: Sphinx, what: str, name: str, obj: ModuleType, _options: Options, lines: List[str]): + kwargs = { + "obj_type": what, + "obj_name": name, + } + + try: + path = Path(inspect.getfile(obj)) + + kwargs.update(abspath=str(path.resolve()), filename=path.name, filestem=path.stem) + except TypeError: + pass + + for i, line in enumerate(lines): + if "{" in line and "}" in line: + lines[i] = line.format(**kwargs) + + +def setup(app): + app.connect("autodoc-process-docstring", format_docstring) diff --git a/docs/source/example/example.boilerplate.rst b/docs/source/example/example.boilerplate.rst new file mode 100644 index 0000000..d570961 --- /dev/null +++ b/docs/source/example/example.boilerplate.rst @@ -0,0 +1,7 @@ +boilerplate +=========== + +.. automodule:: example.boilerplate + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/example/example.rst b/docs/source/example/example.rst new file mode 100644 index 0000000..ff954f3 --- /dev/null +++ b/docs/source/example/example.rst @@ -0,0 +1,46 @@ +example directory +================= + +.. include:: ../../../example/README.md + :parser: myst_parser.sphinx_ + +Content +------- + +.. automodule:: example.download_modules + +.. automodule:: example.download_presets + +.. automodule:: example.download_shelves + +.. automodule:: example.download_stories_and_content + +.. automodule:: example.generate_controlnet_masks + +.. automodule:: example.generate_image + +.. automodule:: example.generate_image_test_samplers + +.. automodule:: example.generate_image_with_controlnet + +.. automodule:: example.generate_image_with_img2img + +.. automodule:: example.generate_text + +.. automodule:: example.generate_voice + +.. automodule:: example.login + +.. automodule:: example.login_with_proxy + +.. automodule:: example.suggest_tags + +.. automodule:: example.upscale_image + +Reference +--------- + +.. toctree:: + :maxdepth: 2 + + example.boilerplate diff --git a/docs/source/index.rst b/docs/source/index.rst index 7ac9422..cbd9ab9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,7 +14,20 @@ TODO Reference ========= + +novelai-api +----------- + +.. toctree:: + :maxdepth: 2 + + novelai_api/novelai_api + + +example +------- + .. toctree:: :maxdepth: 2 - novelai_api + example/example diff --git a/docs/source/novelai_api.python_utils.rst b/docs/source/novelai_api.python_utils.rst deleted file mode 100644 index 2a7b803..0000000 --- a/docs/source/novelai_api.python_utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -novelai\_api.python\_utils module -================================= - -.. automodule:: novelai_api.python_utils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/novelai_api.BanList.rst b/docs/source/novelai_api/novelai_api.BanList.rst similarity index 61% rename from docs/source/novelai_api.BanList.rst rename to docs/source/novelai_api/novelai_api.BanList.rst index 84f16bd..3e05c86 100644 --- a/docs/source/novelai_api.BanList.rst +++ b/docs/source/novelai_api/novelai_api.BanList.rst @@ -1,5 +1,5 @@ -novelai\_api.BanList module -=========================== +novelai\_api.BanList +==================== .. automodule:: novelai_api.BanList :members: diff --git a/docs/source/novelai_api.BiasGroup.rst b/docs/source/novelai_api/novelai_api.BiasGroup.rst similarity index 60% rename from docs/source/novelai_api.BiasGroup.rst rename to docs/source/novelai_api/novelai_api.BiasGroup.rst index add1570..f654322 100644 --- a/docs/source/novelai_api.BiasGroup.rst +++ b/docs/source/novelai_api/novelai_api.BiasGroup.rst @@ -1,5 +1,5 @@ -novelai\_api.BiasGroup module -============================= +novelai\_api.BiasGroup +====================== .. automodule:: novelai_api.BiasGroup :members: diff --git a/docs/source/novelai_api.GlobalSettings.rst b/docs/source/novelai_api/novelai_api.GlobalSettings.rst similarity index 58% rename from docs/source/novelai_api.GlobalSettings.rst rename to docs/source/novelai_api/novelai_api.GlobalSettings.rst index b6d1512..f522b6a 100644 --- a/docs/source/novelai_api.GlobalSettings.rst +++ b/docs/source/novelai_api/novelai_api.GlobalSettings.rst @@ -1,5 +1,5 @@ -novelai\_api.GlobalSettings module -================================== +novelai\_api.GlobalSettings +=========================== .. automodule:: novelai_api.GlobalSettings :members: diff --git a/docs/source/novelai_api.Idstore.rst b/docs/source/novelai_api/novelai_api.Idstore.rst similarity index 61% rename from docs/source/novelai_api.Idstore.rst rename to docs/source/novelai_api/novelai_api.Idstore.rst index d93c4ca..8a51f1a 100644 --- a/docs/source/novelai_api.Idstore.rst +++ b/docs/source/novelai_api/novelai_api.Idstore.rst @@ -1,5 +1,5 @@ -novelai\_api.Idstore module -=========================== +novelai\_api.Idstore +==================== .. automodule:: novelai_api.Idstore :members: diff --git a/docs/source/novelai_api.ImagePreset.rst b/docs/source/novelai_api/novelai_api.ImagePreset.rst similarity index 59% rename from docs/source/novelai_api.ImagePreset.rst rename to docs/source/novelai_api/novelai_api.ImagePreset.rst index 544da85..0cca697 100644 --- a/docs/source/novelai_api.ImagePreset.rst +++ b/docs/source/novelai_api/novelai_api.ImagePreset.rst @@ -1,5 +1,5 @@ -novelai\_api.ImagePreset module -=============================== +novelai\_api.ImagePreset +======================== .. automodule:: novelai_api.ImagePreset :members: diff --git a/docs/source/novelai_api.Keystore.rst b/docs/source/novelai_api/novelai_api.Keystore.rst similarity index 61% rename from docs/source/novelai_api.Keystore.rst rename to docs/source/novelai_api/novelai_api.Keystore.rst index 2c0a814..8a3546b 100644 --- a/docs/source/novelai_api.Keystore.rst +++ b/docs/source/novelai_api/novelai_api.Keystore.rst @@ -1,5 +1,5 @@ -novelai\_api.Keystore module -============================ +novelai\_api.Keystore +===================== .. automodule:: novelai_api.Keystore :members: diff --git a/docs/source/novelai_api.NovelAIError.rst b/docs/source/novelai_api/novelai_api.NovelAIError.rst similarity index 59% rename from docs/source/novelai_api.NovelAIError.rst rename to docs/source/novelai_api/novelai_api.NovelAIError.rst index 5a70b41..0241501 100644 --- a/docs/source/novelai_api.NovelAIError.rst +++ b/docs/source/novelai_api/novelai_api.NovelAIError.rst @@ -1,5 +1,5 @@ -novelai\_api.NovelAIError module -================================ +novelai\_api.NovelAIError +========================= .. automodule:: novelai_api.NovelAIError :members: diff --git a/docs/source/novelai_api.NovelAI_API.rst b/docs/source/novelai_api/novelai_api.NovelAI_API.rst similarity index 59% rename from docs/source/novelai_api.NovelAI_API.rst rename to docs/source/novelai_api/novelai_api.NovelAI_API.rst index 05de576..b9cda84 100644 --- a/docs/source/novelai_api.NovelAI_API.rst +++ b/docs/source/novelai_api/novelai_api.NovelAI_API.rst @@ -1,5 +1,5 @@ -novelai\_api.NovelAI\_API module -================================ +novelai\_api.NovelAI\_API +========================= .. automodule:: novelai_api.NovelAI_API :members: diff --git a/docs/source/novelai_api.Preset.rst b/docs/source/novelai_api/novelai_api.Preset.rst similarity index 62% rename from docs/source/novelai_api.Preset.rst rename to docs/source/novelai_api/novelai_api.Preset.rst index 610cfce..61ced36 100644 --- a/docs/source/novelai_api.Preset.rst +++ b/docs/source/novelai_api/novelai_api.Preset.rst @@ -1,5 +1,5 @@ -novelai\_api.Preset module -========================== +novelai\_api.Preset +=================== .. automodule:: novelai_api.Preset :members: diff --git a/docs/source/novelai_api.SchemaValidator.rst b/docs/source/novelai_api/novelai_api.SchemaValidator.rst similarity index 57% rename from docs/source/novelai_api.SchemaValidator.rst rename to docs/source/novelai_api/novelai_api.SchemaValidator.rst index 1fa1e47..a0f0f2f 100644 --- a/docs/source/novelai_api.SchemaValidator.rst +++ b/docs/source/novelai_api/novelai_api.SchemaValidator.rst @@ -1,5 +1,5 @@ -novelai\_api.SchemaValidator module -=================================== +novelai\_api.SchemaValidator +============================ .. automodule:: novelai_api.SchemaValidator :members: diff --git a/docs/source/novelai_api.StoryHandler.rst b/docs/source/novelai_api/novelai_api.StoryHandler.rst similarity index 59% rename from docs/source/novelai_api.StoryHandler.rst rename to docs/source/novelai_api/novelai_api.StoryHandler.rst index a59d145..69a97cc 100644 --- a/docs/source/novelai_api.StoryHandler.rst +++ b/docs/source/novelai_api/novelai_api.StoryHandler.rst @@ -1,5 +1,5 @@ -novelai\_api.StoryHandler module -================================ +novelai\_api.StoryHandler +========================= .. automodule:: novelai_api.StoryHandler :members: diff --git a/docs/source/novelai_api.Tokenizer.rst b/docs/source/novelai_api/novelai_api.Tokenizer.rst similarity index 60% rename from docs/source/novelai_api.Tokenizer.rst rename to docs/source/novelai_api/novelai_api.Tokenizer.rst index f5de5ca..8361c6b 100644 --- a/docs/source/novelai_api.Tokenizer.rst +++ b/docs/source/novelai_api/novelai_api.Tokenizer.rst @@ -1,5 +1,5 @@ -novelai\_api.Tokenizer module -============================= +novelai\_api.Tokenizer +====================== .. automodule:: novelai_api.Tokenizer :members: diff --git a/docs/source/novelai_api.rst b/docs/source/novelai_api/novelai_api.rst similarity index 83% rename from docs/source/novelai_api.rst rename to docs/source/novelai_api/novelai_api.rst index fc3b5a7..b10e05b 100644 --- a/docs/source/novelai_api.rst +++ b/docs/source/novelai_api/novelai_api.rst @@ -1,5 +1,5 @@ -novelai\_api package -==================== +novelai-api package +=================== .. toctree:: :maxdepth: 2 @@ -17,4 +17,3 @@ novelai\_api package novelai_api.StoryHandler novelai_api.Tokenizer novelai_api.utils - novelai_api.python_utils diff --git a/docs/source/novelai_api.utils.rst b/docs/source/novelai_api/novelai_api.utils.rst similarity index 63% rename from docs/source/novelai_api.utils.rst rename to docs/source/novelai_api/novelai_api.utils.rst index 4be55a4..5d406bc 100644 --- a/docs/source/novelai_api.utils.rst +++ b/docs/source/novelai_api/novelai_api.utils.rst @@ -1,5 +1,5 @@ -novelai\_api.utils module -========================= +novelai\_api.utils +================== .. automodule:: novelai_api.utils :members: diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000..5f9e641 --- /dev/null +++ b/example/README.md @@ -0,0 +1,18 @@ +### Requirements +Requires the "NAI_USERNAME" and "NAI_PASSWORD" values provided via environment variables. + +The "NAI_PROXY" environment variable is also supported to inject a proxy address. + +### Usage +If you have the novelai-api package installed via pip : +``` +python example/ +``` + +
+ +If you don't have the novelai-api package installed, or you're actively developing on the project : +``` +nox -s run -- example/ +``` +This option supports providing environment variables through a .env file diff --git a/example/boilerplate.py b/example/boilerplate.py index 4cf9678..53e21d6 100644 --- a/example/boilerplate.py +++ b/example/boilerplate.py @@ -10,6 +10,10 @@ class API: + """ + Boilerplate for the redundant parts + """ + _username: str _password: str _session: ClientSession @@ -47,6 +51,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class JSONEncoder(json.JSONEncoder): + """ + Extended JSON encoder to support bytes + """ + def default(self, o: Any) -> Any: if isinstance(o, bytes): return o.hex() @@ -55,4 +63,8 @@ def default(self, o: Any) -> Any: def dumps(e: Any) -> str: + """ + Shortcut to a configuration of json.dumps for consistency + """ + return json.dumps(e, indent=4, ensure_ascii=False, cls=JSONEncoder) diff --git a/example/download_modules.py b/example/download_modules.py index 0c4b7fd..b8ad283 100644 --- a/example/download_modules.py +++ b/example/download_modules.py @@ -1,7 +1,13 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API, dumps +Example of how to download and decrypt modules from the provided account +""" +import asyncio + +from example.boilerplate import API, dumps from novelai_api.utils import decrypt_user_data, encrypt_user_data @@ -22,4 +28,5 @@ async def main(): encrypt_user_data(modules, keystore) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/download_presets.py b/example/download_presets.py index cce7d8b..15db874 100644 --- a/example/download_presets.py +++ b/example/download_presets.py @@ -1,7 +1,13 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API, dumps +Example of how to download and decompress shelves from the provided account +""" +import asyncio + +from example.boilerplate import API, dumps from novelai_api.utils import compress_user_data, decompress_user_data @@ -18,4 +24,5 @@ async def main(): compress_user_data(presets) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/download_shelves.py b/example/download_shelves.py index 3cf7c2f..d137c6d 100644 --- a/example/download_shelves.py +++ b/example/download_shelves.py @@ -1,7 +1,13 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API, dumps +Example of how to download and decompress shelves from the provided account +""" +import asyncio + +from example.boilerplate import API, dumps from novelai_api.utils import compress_user_data, decompress_user_data @@ -18,4 +24,5 @@ async def main(): compress_user_data(shelves) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/download_stories_and_content.py b/example/download_stories_and_content.py index 484c484..363937a 100644 --- a/example/download_stories_and_content.py +++ b/example/download_stories_and_content.py @@ -1,7 +1,13 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API, dumps +Example of how to download and decrypt stories from the provided account +""" +import asyncio + +from example.boilerplate import API, dumps from novelai_api.utils import decrypt_user_data, encrypt_user_data, link_content_to_story, unlink_content_from_story @@ -33,4 +39,5 @@ async def main(): encrypt_user_data(story_contents, keystore) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_controlnet_masks.py b/example/generate_controlnet_masks.py index 8620af6..4994ed9 100644 --- a/example/generate_controlnet_masks.py +++ b/example/generate_controlnet_masks.py @@ -1,10 +1,19 @@ +""" +{filename} +============================================================================== + +| Example of how to query the controlnet masks for an image +| +| It expects an image "results/image.png" to exist and will generate the resulting masks in this same folder +| NOTE: Currently the returned mask is wrong due to an image conversion in frontend (see :issue:`15`) +""" + import asyncio import base64 import time from pathlib import Path -from boilerplate import API - +from example.boilerplate import API from novelai_api.ImagePreset import ControlNetModel from novelai_api.NovelAIError import NovelAIError @@ -33,4 +42,5 @@ async def main(): time.sleep(5) -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_image.py b/example/generate_image.py index a78b194..0094799 100644 --- a/example/generate_image.py +++ b/example/generate_image.py @@ -1,8 +1,16 @@ -from asyncio import run -from pathlib import Path +""" +{filename} +============================================================================== + +| Example of how to generate an image +| +| The resulting images will be placed in a folder named "results" +""" -from boilerplate import API +import asyncio +from pathlib import Path +from example.boilerplate import API from novelai_api.ImagePreset import ImageModel, ImagePreset, ImageResolution, UCPreset @@ -42,4 +50,5 @@ async def main(): f.write(img) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_image_test_samplers.py b/example/generate_image_test_samplers.py index ee6c8d8..cf673b7 100644 --- a/example/generate_image_test_samplers.py +++ b/example/generate_image_test_samplers.py @@ -1,9 +1,17 @@ +""" +{filename} +============================================================================== + +| Test on which sampler currently work. It will create one image per sampler +| +| The resulting images will be placed in a folder named "results" +""" + +import asyncio import time -from asyncio import run from pathlib import Path -from boilerplate import API - +from example.boilerplate import API from novelai_api.ImagePreset import ImageModel, ImagePreset, ImageSampler from novelai_api.NovelAIError import NovelAIError @@ -33,4 +41,5 @@ async def main(): time.sleep(5) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_image_with_controlnet.py b/example/generate_image_with_controlnet.py index 12b0f72..40e0c69 100644 --- a/example/generate_image_with_controlnet.py +++ b/example/generate_image_with_controlnet.py @@ -1,9 +1,18 @@ +""" +{filename} +============================================================================== + +| Example of how to generate an image with a Control Net +| +| The resulting image will be placed in a folder named "results" +| NOTE: Currently the returned mask is wrong due to an image conversion in frontend (see :issue:`15`) +""" + import asyncio import base64 from pathlib import Path -from boilerplate import API - +from example.boilerplate import API from novelai_api.ImagePreset import ControlNetModel, ImageModel, ImagePreset @@ -31,4 +40,5 @@ async def main(): f.write(img) -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_image_with_img2img.py b/example/generate_image_with_img2img.py index d83d4bd..976d9b5 100644 --- a/example/generate_image_with_img2img.py +++ b/example/generate_image_with_img2img.py @@ -1,9 +1,17 @@ +""" +{filename} +============================================================================== + +| Example of how to generate an image with img2img +| +| The resulting image will be placed in a folder named "results" +""" + import asyncio import base64 from pathlib import Path -from boilerplate import API - +from example.boilerplate import API from novelai_api.ImagePreset import ImageGenerationType, ImageModel, ImagePreset @@ -30,4 +38,5 @@ async def main(): f.write(img) -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_text.py b/example/generate_text.py index fee4453..532cc0b 100644 --- a/example/generate_text.py +++ b/example/generate_text.py @@ -1,8 +1,16 @@ -from asyncio import run -from typing import List, Optional +""" +{filename} +============================================================================== + +| Example of how to generate a text +| +| The resulting text will be directed to the standard error output (stderr) +""" -from boilerplate import API +import asyncio +from typing import List, Optional +from example.boilerplate import API from novelai_api.BanList import BanList from novelai_api.BiasGroup import BiasGroup from novelai_api.GlobalSettings import GlobalSettings @@ -102,4 +110,5 @@ async def main(): # ... and more examples can be found in tests/test_generate.py -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/generate_voice.py b/example/generate_voice.py index 9a68565..f2894e3 100644 --- a/example/generate_voice.py +++ b/example/generate_voice.py @@ -1,7 +1,17 @@ -from asyncio import run +""" +{filename} +============================================================================== + +| Example of how to generate a voice (TTS - Text To Speech) +| +| The resulting audio sample will be placed in a folder named "results" +| The input is limited to 1000 characters (it will cut at 1000 in backend) +""" + +import asyncio from pathlib import Path -from boilerplate import API +from example.boilerplate import API # tts_file = "tts.webm" tts_file = "tts.mp3" @@ -36,4 +46,5 @@ async def main(): logger.info(f"TTS saved in {tts_file}") -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/login.py b/example/login.py index b75273e..f3a3d52 100644 --- a/example/login.py +++ b/example/login.py @@ -1,6 +1,13 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API +Example of how to login on the provided account +""" + +import asyncio + +from example.boilerplate import API async def main(): @@ -8,4 +15,5 @@ async def main(): print(api_handler.api.headers) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/login_with_proxy.py b/example/login_with_proxy.py index 3f74c30..9b43975 100644 --- a/example/login_with_proxy.py +++ b/example/login_with_proxy.py @@ -1,6 +1,13 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API +Example of how to login on the provided account with a proxy +""" + +import asyncio + +from example.boilerplate import API async def main(): @@ -15,4 +22,5 @@ async def main(): pass -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/suggest_tags.py b/example/suggest_tags.py index 72965cb..306d632 100644 --- a/example/suggest_tags.py +++ b/example/suggest_tags.py @@ -1,7 +1,15 @@ -from asyncio import run +""" +{filename} +============================================================================== -from boilerplate import API, dumps +| Example of tag suggestion for image gen +| +| The result will be directed to the standard error output (stderr) +""" +import asyncio + +from example.boilerplate import API, dumps from novelai_api.ImagePreset import ImageModel tags = ["gi", "bo", "scal", "cre"] @@ -19,4 +27,5 @@ async def main(): logger.info(dumps(e)) -run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/upscale_image.py b/example/upscale_image.py index 155cb5f..0468d9e 100644 --- a/example/upscale_image.py +++ b/example/upscale_image.py @@ -1,10 +1,19 @@ +""" +{filename} +============================================================================== + +| Example of how to upscale an image +| +| It expects an image "results/image.png" to exist and will generate the resulting masks in this same folder +| The image should be 512x768 by default, modify :code:`image_size` to change it +""" + import asyncio import base64 import time from pathlib import Path -from boilerplate import API - +from example.boilerplate import API from novelai_api.NovelAIError import NovelAIError @@ -35,4 +44,5 @@ async def main(): time.sleep(5) -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/novelai_api/BanList.py b/novelai_api/BanList.py index e6f093d..f30a901 100644 --- a/novelai_api/BanList.py +++ b/novelai_api/BanList.py @@ -10,6 +10,13 @@ class BanList: enabled: bool def __init__(self, *sequences: Union[List[int], str], enabled: bool = True): + """ + Create a ban list with the given elements. Elements can be string or tokenized strings + Using tokenized strings is not recommended, for flexibility between tokenizers + + :param enabled: Is the ban list enabled + """ + self.enabled = enabled self._sequences = [] @@ -20,6 +27,11 @@ def add( self, *sequences: Union[Dict[str, List[List[int]]], Dict[str, List[int]], List[int], str], ) -> "BanList": + """ + Add elements to the ban list. Elements can be string or tokenized strings + Using tokenized strings is not recommended, for flexibility between tokenizers + """ + for i, sequence in enumerate(sequences): if "sequence" in sequence: sequence = sequence["sequence"] @@ -44,15 +56,30 @@ def add( return self def __iadd__(self, o: Union[List[int], str]) -> "BanList": + """ + Add elements to the ban list. Elements can be string or tokenized strings + Using tokenized strings is not recommended, for flexibility between tokenizers + """ + self.add(o) return self def __iter__(self): + """ + Return an iterator on the stored sequences + """ + return self._sequences.__iter__() def get_tokenized_banlist(self, model: Model) -> Iterable[List[int]]: - return (tokenize_if_not(model, s) for s in self._sequences) + """ + Return the tokenized sequences for the ban list, if it is enabled + + :param model: Model to use for tokenization + """ + + return (tokenize_if_not(model, s) for s in self._sequences if self.enabled) def __str__(self) -> str: return self._sequences.__str__() diff --git a/novelai_api/BiasGroup.py b/novelai_api/BiasGroup.py index 711fd0d..4d24f6f 100644 --- a/novelai_api/BiasGroup.py +++ b/novelai_api/BiasGroup.py @@ -19,6 +19,15 @@ def __init__( generate_once: bool = False, enabled: bool = True, ): + """ + Create a bias group + + :param bias: Bias value of the bias group. Negative is a downbias, positive is an upbias + :param ensure_sequence_finish: Ensures the bias completes + :param generate_once: Only biases for the first occurrence + :param enabled: Is the bias group enabled + """ + self._sequences = [] self.bias = bias @@ -28,6 +37,10 @@ def __init__( @classmethod def from_data(cls, data: Dict[str, Any]) -> "BiasGroup": + """ + Create a bias group from bias group data + """ + # FIXME: wtf is "whenInactive" in bias ? ensure_sequence_finish = ( data["ensureSequenceFinish"] @@ -36,6 +49,7 @@ def from_data(cls, data: Dict[str, Any]) -> "BiasGroup": if "ensure_sequence_finish" in data else False ) + generate_once = ( data["generateOnce"] if "generateOnce" in data @@ -55,6 +69,11 @@ def add( self, *sequences: Union[Dict[str, List[List[int]]], Dict[str, List[int]], List[int], str], ) -> "BiasGroup": + """ + Add elements to the bias group. Elements can be string or tokenized strings + Using tokenized strings is not recommended, for flexibility between tokenizers + """ + for i, sequence in enumerate(sequences): if isinstance(sequence, dict): if "sequence" in sequence: @@ -79,12 +98,23 @@ def add( return self - def __iadd__(self, o: List[int]) -> "BiasGroup": - self.add(o) + def __iadd__( + self, sequences: Union[Dict[str, List[List[int]]], Dict[str, List[int]], List[int], str] + ) -> "BiasGroup": + """ + Add elements to the bias group. Elements can be string or tokenized strings + Using tokenized strings is not recommended, for flexibility between tokenizers + """ + + self.add(sequences) return self def __iter__(self): + """ + Return an iterator on the stored sequences + """ + return ( { "bias": self.bias, @@ -97,15 +127,21 @@ def __iter__(self): ) def get_tokenized_biases(self, model: Model) -> Iterable[Dict[str, any]]: + """ + Return the tokenized sequences for the bias group, if it is enabled + + :param model: Model to use for tokenization + """ + return ( { "bias": self.bias, "ensure_sequence_finish": self.ensure_sequence_finish, "generate_once": self.generate_once, - "enabled": self.enabled, "sequence": tokenize_if_not(model, s), } for s in self._sequences + if self.enabled ) def __str__(self) -> str: diff --git a/novelai_api/GlobalSettings.py b/novelai_api/GlobalSettings.py index 085d7fa..9348f07 100644 --- a/novelai_api/GlobalSettings.py +++ b/novelai_api/GlobalSettings.py @@ -7,6 +7,10 @@ class GlobalSettings: + """ + Object used to store global settings for the account + """ + # TODO: store bracket ban in a file _BRACKETS = { "gpt2": [ @@ -609,6 +613,7 @@ class GlobalSettings: #: Apply the GENJI_AMBIGUOUS_TOKENS if model is Genji ban_ambiguous_genji_tokens: bool + #: Value to set num_logprobs at to disable logprobs NO_LOGPROBS = -1 _settings: Dict[str, Any] @@ -649,9 +654,19 @@ def __getattr__(self, key): return object.__getattribute__(self, key) def copy(self): + """ + Create a new GlobalSettings from the current + """ + return GlobalSettings(**self._settings) def to_settings(self, model: Model) -> Dict[str, Any]: + """ + Create text generation settings from the GlobalSettings object + + :param model: Model to use the settings of + """ + settings = { "generate_until_sentence": self._settings["generate_until_sentence"], "num_logprobs": self._settings["num_logprobs"], @@ -662,6 +677,9 @@ def to_settings(self, model: Model) -> Dict[str, Any]: "use_cache": False, } + if self._settings["num_logprobs"] != self.NO_LOGPROBS: + settings["num_logprobs"] = self._settings["num_logprobs"] + tokenizer_name = Tokenizer.get_tokenizer_name(model) if self._settings["ban_brackets"]: diff --git a/novelai_api/ImagePreset.py b/novelai_api/ImagePreset.py index 4d7c54f..b8c9db3 100644 --- a/novelai_api/ImagePreset.py +++ b/novelai_api/ImagePreset.py @@ -9,6 +9,10 @@ class ImageModel(enum.Enum): + """ + Image model for low_level.suggest_tags() and low_level.generate_image() + """ + Anime_Curated = "safe-diffusion" Anime_Full = "nai-diffusion" Furry = "nai-diffusion-furry" @@ -17,6 +21,10 @@ class ImageModel(enum.Enum): class ControlNetModel(enum.Enum): + """ + ControlNet Model for ImagePreset.controlnet_model and low_level.generate_controlnet_mask() + """ + Palette_Swap = "hed" Form_Lock = "midas" Scrible = "fake_scribble" @@ -25,6 +33,10 @@ class ControlNetModel(enum.Enum): class ImageResolution(enum.Enum): + """ + Image resolution for ImagePreset.resolution + """ + Small_Portrait = (384, 640) Small_Landscape = (640, 384) Small_Square = (512, 512) @@ -39,6 +51,10 @@ class ImageResolution(enum.Enum): class ImageSampler(enum.Enum): + """ + Sampler for ImagePreset.sampler + """ + k_lms = "k_lms" k_euler = "k_euler" k_euler_ancestral = "k_euler_ancestral" @@ -59,6 +75,10 @@ class ImageSampler(enum.Enum): class UCPreset(enum.Enum): + """ + Default UC preset for ImagePreset.uc_preset + """ + Preset_Low_Quality_Bad_Anatomy = 0 Preset_Low_Quality = 1 Preset_Bad_Anatomy = 2 @@ -66,8 +86,13 @@ class UCPreset(enum.Enum): class ImageGenerationType(enum.Enum): + """ + Image generation type for low_level.generate_image + """ + NORMAL = "generate" IMG2IMG = "img2img" + # inpainting should go there class ImagePreset: @@ -182,6 +207,7 @@ class ImagePreset: "uc_preset": UCPreset.Preset_Low_Quality_Bad_Anatomy, "n_samples": 1, "seed": 0, + # TODO: set ImageSampler.k_dpmpp_2m as default ? "sampler": ImageSampler.k_euler_ancestral, "steps": 28, "scale": 11, @@ -194,7 +220,7 @@ class ImagePreset: _settings: Dict[str, Any] - # Seed provided when generating an image with seed 0 (default). Seed is also in metadata, but might be a hassle + #: Seed provided when generating an image with seed 0 (default). Seed is also in metadata, but might be a hassle last_seed: int @expand_kwargs(_TYPE_MAPPING.keys(), _TYPE_MAPPING.values()) @@ -223,15 +249,27 @@ def __delitem__(self, key): del self._settings[key] def __contains__(self, key: str): - return key in self._settings + return key in self._settings.keys() + + def update(self, values: Optional[Dict[str, Any]] = None, **kwargs) -> "ImagePreset": + """ + Update the settings stored in the preset. Works like dict.update() + """ - def update(self, values: Dict[str, Any]) -> "ImagePreset": - for k, v in values.items(): + if values is not None: + for k, v in values.items(): + self[k] = v + + for k, v in kwargs.items(): self[k] = v return self def copy(self) -> "ImagePreset": + """ + Create a new ImagePreset instance from the current one + """ + return ImagePreset(**self._settings) # give dot access capabilities to the object @@ -254,6 +292,12 @@ def __delattr__(self, name): object.__delattr__(self, name) def to_settings(self, model: ImageModel) -> Dict[str, Any]: + """ + Return the values stored in the preset, for a generate_image function + + :param model: Image model to get the settings of + """ + settings = copy.deepcopy(self._settings) resolution: Union[ImageResolution, Tuple[int, int]] = settings.pop("resolution") @@ -301,6 +345,10 @@ def to_settings(self, model: ImageModel) -> Dict[str, Any]: return settings def get_max_n_samples(self): + """ + Get the allowed max value of ImagePreset.n_samples using current preset values + """ + resolution: Union[ImageResolution, Tuple[int, int]] = self._settings["resolution"] if isinstance(resolution, ImageResolution): @@ -347,11 +395,23 @@ def calculate_cost(self, is_opus: bool): @classmethod def from_file(cls, path: str) -> "ImagePreset": + """ + Write the preset to a file + + :param path: Path to the file to read the preset from + """ + with open(path, encoding="utf-8") as f: data = json.loads(f.read()) return cls(**data) def to_file(self, path: str): + """ + Load the preset from a file + + :param path: Path to the file to write the preset to + """ + with open(path, "w", encoding="utf-8") as f: f.write(json.dumps(self._settings)) diff --git a/novelai_api/Keystore.py b/novelai_api/Keystore.py index f0ad8b2..3ddd489 100644 --- a/novelai_api/Keystore.py +++ b/novelai_api/Keystore.py @@ -68,6 +68,10 @@ def __str__(self) -> str: return str(self._keystore) def create(self) -> str: + """ + Create a new meta that is not in the keystore and assign a random nonce to it + """ + if not self._decrypted: raise ValueError("Cannot set key in an encrypted keystore") @@ -80,6 +84,12 @@ def create(self) -> str: return meta def decrypt(self, key: bytes): + """ + Decrypt the keystore. The encrypted data should be in Keystore.data + + :param key: Encryption key computed from utils.get_encryption_key() + """ + keystore = self.data.copy() if "keystore" in keystore and keystore["keystore"] is None: # keystore is null when empty @@ -123,6 +133,12 @@ def decrypt(self, key: bytes): self._decrypted = True def encrypt(self, key: bytes): + """ + Encrypt a decrypted keystore. The encrypted data will be at Keystore.data + + :param key: Encryption key computed from utils.get_encryption_key() + """ + # keystore is not decrypted, no need to encrypt it if not self._decrypted: return diff --git a/novelai_api/NovelAIError.py b/novelai_api/NovelAIError.py index b7edc7e..94881f2 100644 --- a/novelai_api/NovelAIError.py +++ b/novelai_api/NovelAIError.py @@ -1,5 +1,11 @@ class NovelAIError(Exception): + """ + Expected raised by the NAI API when a problem occurs + """ + + #: Provided status code, or -1 if no status code was provided status: int + #: Provided error message message: str def __init__(self, status: int, message: str) -> None: diff --git a/novelai_api/Preset.py b/novelai_api/Preset.py index 1dfd7cd..8f2e339 100644 --- a/novelai_api/Preset.py +++ b/novelai_api/Preset.py @@ -16,19 +16,19 @@ class Order(IntEnum): NAME_TO_ORDER = { - "tfs": Order.TFS, "temperature": Order.Temperature, - "top_p": Order.Top_P, "top_k": Order.Top_K, + "top_p": Order.Top_P, + "tfs": Order.TFS, "top_a": Order.Top_A, "typical_p": Order.Typical_P, } ORDER_TO_NAME = { - Order.TFS: "tfs", Order.Temperature: "temperature", - Order.Top_P: "top_p", Order.Top_K: "top_k", + Order.Top_P: "top_p", + Order.TFS: "tfs", Order.Top_A: "top_a", Order.Typical_P: "typical_p", } @@ -176,7 +176,10 @@ class Preset(metaclass=_PresetMetaclass): _enabled: List[bool] _settings: Dict[str, Any] + + #: Name of the preset name: str + #: Model the preset is for model: Model def __init__(self, name: str, model: Model, settings: Optional[Dict[str, Any]] = None): @@ -240,13 +243,28 @@ def __delattr__(self, name): def __repr__(self) -> str: model = self.model.value if self.model is not None else "" - return f"Preset: '{self.name} ({model})'" + enabled_keys = ", ".join(f"{k} = {v}" for k, v in zip(self._enabled, NAME_TO_ORDER.keys())) + + return f"Preset: '{self.name} ({model}, {enabled_keys})'" def enable(self, **kwargs) -> "Preset": + """ + Enable/disable the processing of sampling values (True to enable, False to disable). + + The allowed keys are : + * tfs + * temperature + * top_p + * top_k + * top_a + * typical_p + """ + for o in Order: name = ORDER_TO_NAME[o] - enabled = kwargs.pop(name, False) - self._enabled[o.value] = enabled + enabled = kwargs.pop(name, None) + if enabled is not None: + self._enabled[o.value] = enabled if len(kwargs): raise ValueError(f"Invalid order name: {', '.join(kwargs)}") @@ -254,6 +272,10 @@ def enable(self, **kwargs) -> "Preset": return self def to_settings(self) -> Dict[str, Any]: + """ + Return the values stored in the preset, for a generate function + """ + settings = deepcopy(self._settings) if "textGenerationSettingsVersion" in settings: @@ -263,33 +285,65 @@ def to_settings(self) -> Dict[str, Any]: if not self._enabled[i]: settings["order"].remove(o) + # Delete the options that return an unknown error (success status code, but server error) + if settings.get("repetition_penalty_slope", None) == 0: + del settings["repetition_penalty_slope"] + return settings def to_file(self, path: str) -> NoReturn: + """ + Write the current preset to a file + + :param path: Path to the preset file to write + """ + raise NotImplementedError() def copy(self) -> "Preset": + """ + Instantiate a new preset object from the current one + """ + return Preset(self.name, self.model, deepcopy(self._settings)) def set(self, name: str, value: Any) -> "Preset": + """ + Set a preset value. Same as `preset[name] = value` + """ + self[name] = value return self - def update(self, values: Dict[str, Any]) -> "Preset": - for k, v in values.items(): + def update(self, values: Optional[Dict[str, Any]] = None, **kwargs) -> "Preset": + """ + Update the settings stored in the preset. Works like dict.update() + """ + + if values is not None: + for k, v in values.items(): + self[k] = v + + for k, v in kwargs.items(): self[k] = v return self @classmethod def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "Preset": + """ + Instantiate a preset from preset data, the data should be the same as in a preset file. + Works like dict.update() + """ + if data is None: data = {} data.update(kwargs) name = data["name"] if "name" in data else "" + # FIXME: collapse model version model_name = data["model"] if "model" in data else "" model = Model(model_name) if enum_contains(Model, model_name) else None @@ -312,6 +366,12 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P @classmethod def from_file(cls, path: str) -> "Preset": + """ + Instantiate a preset from the given file + + :param path: Path to the preset file + """ + with open(path, encoding="utf-8") as f: data = loads(f.read()) @@ -319,6 +379,15 @@ def from_file(cls, path: str) -> "Preset": @classmethod def from_official(cls, model: Model, name: Optional[str] = None) -> Union["Preset", None]: + """ + Return a copy of an official preset + + :param model: Model to get the preset of + :param name: Name of the preset. None means a random official preset should be returned + + :return: The chosen preset, or None if the name was not found in the list of official presets + """ + model_value: str = model.value if name is None: @@ -333,6 +402,14 @@ def from_official(cls, model: Model, name: Optional[str] = None) -> Union["Prese @classmethod def from_default(cls, model: Model) -> Union["Preset", None]: + """ + Return a copy of the default preset for the given model + + :param model: Model to get the default preset of + + :return: The chosen preset, or None if the default preset was not found for the model + """ + model_value: str = model.value default = cls._defaults.get(model_value) @@ -340,13 +417,17 @@ def from_default(cls, model: Model) -> Union["Preset", None]: return None preset = cls._officials[model_value].get(default) - if preset is None: - return None + if preset is not None: + preset = deepcopy(preset) + + return preset - return preset.copy() +def _import_officials(): + """ + Import the official presets under the 'presets' directory. Performed once, at import + """ -def import_officials(): cls = Preset cls._officials_values = {} @@ -373,4 +454,4 @@ def import_officials(): if not hasattr(Preset, "_officials"): - import_officials() + _import_officials() diff --git a/novelai_api/Tokenizer.py b/novelai_api/Tokenizer.py index f153e21..43b3bba 100644 --- a/novelai_api/Tokenizer.py +++ b/novelai_api/Tokenizer.py @@ -33,6 +33,12 @@ class Tokenizer: @classmethod def get_tokenizer_name(cls, model: Model) -> str: + """ + Get the tokenizer name a model uses + + :param model: Model to get the tokenizer name of + """ + return cls._tokenizers_name[model] _GPT2_PATH = tokenizers_path / "gpt2_tokenizer.json" @@ -56,6 +62,15 @@ def get_tokenizer_name(cls, model: Model) -> str: @classmethod def decode(cls, model: AnyModel, o: List[int]) -> str: + """ + Decode the provided tokens using the chosen tokenizer + + :param model: Model to use the tokenizer of + :param o: List of tokens to decode + + :return: Text the provided tokens decode into + """ + tokenizer_name = cls._tokenizers_name[model] tokenizer = cls._tokenizers[tokenizer_name] @@ -63,6 +78,15 @@ def decode(cls, model: AnyModel, o: List[int]) -> str: @classmethod def encode(cls, model: AnyModel, o: str) -> List[int]: + """ + Encode the provided text using the chosen tokenizer + + :param model: Model to use the tokenizer of + :param o: Text to encode + + :return: List of tokens the provided text encodes into + """ + tokenizer_name = cls._tokenizers_name[model] tokenizer = cls._tokenizers[tokenizer_name] diff --git a/novelai_api/_high_level.py b/novelai_api/_high_level.py index de9b5d6..61efaa4 100644 --- a/novelai_api/_high_level.py +++ b/novelai_api/_high_level.py @@ -1,6 +1,6 @@ import json from hashlib import sha256 -from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Tuple, Type, Union from novelai_api.BanList import BanList from novelai_api.BiasGroup import BiasGroup @@ -63,13 +63,24 @@ async def login(self, email: str, password: str) -> str: return rsp["accessToken"] async def login_from_token(self, access_key: str): + """ + Log in with the access key, instead of email and password + + :param access_key: Access key of the account (pre-computed via email and password) + + :return: User's access token + """ + rsp = await self._parent.low_level.login(access_key) self._parent.headers["Authorization"] = f"Bearer {rsp['accessToken']}" + return rsp["accessToken"] + async def get_keystore(self, key: bytes) -> Keystore: """ Retrieve the keystore and decrypt it in a readable manner. + The keystore is the mapping of meta -> encryption key of each object. If this function throws errors repeatedly at you, check your internet connection or the integrity of your keystore. @@ -86,11 +97,29 @@ async def get_keystore(self, key: bytes) -> Keystore: return keystore async def set_keystore(self, keystore: Keystore, key: bytes) -> bytes: + """ + Encrypt and upload the keystore. + + The keystore is the mapping of meta -> encryption key of each object. + If this function throws errors repeatedly at you, + check your internet connection or the integrity of your keystore. + Losing your keystore, or overwriting it means losing all content on the account. + + :param keystore: Keystore object to upload + :param key: Account's encryption key + + :return: raw data of the serialized Keystore object + """ + keystore.encrypt(key) return await self._parent.low_level.set_keystore(keystore.data) async def download_user_stories(self) -> Dict[str, Dict[str, Union[str, int]]]: + """ + Download all the objects of type 'stories' stored on the account + """ + stories = await self._parent.low_level.download_objects("stories") return stories["objects"] @@ -98,21 +127,37 @@ async def download_user_stories(self) -> Dict[str, Dict[str, Union[str, int]]]: async def download_user_story_contents( self, ) -> Dict[str, Dict[str, Union[str, int]]]: + """ + Download all the objects of type 'storycontent' stored on the account + """ + story_contents = await self._parent.low_level.download_objects("storycontent") return story_contents["objects"] async def download_user_presets(self) -> List[Dict[str, Union[str, int]]]: + """ + Download all the objects of type 'presets' stored on the account + """ + presets = await self._parent.low_level.download_objects("presets") return presets["objects"] async def download_user_modules(self) -> List[Dict[str, Union[str, int]]]: + """ + Download all the objects of type 'aimodules' stored on the account + """ + modules = await self._parent.low_level.download_objects("aimodules") return modules["objects"] async def download_user_shelves(self) -> List[Dict[str, Union[str, int]]]: + """ + Download all the objects of type 'shelf' stored on the account + """ + modules = await self._parent.low_level.download_objects("shelf") return modules["objects"] @@ -124,12 +169,11 @@ async def upload_user_content( keystore: Optional[Keystore] = None, ) -> bool: """ - Upload user content. If it has been decrypted with decrypt_user_data, - it should be re-encrypted with encrypt_user_data, even if the decryption failed + Upload user content :param data: Object to upload - :param encrypt: Encrypt/compress the data if True and not already encrypted - :param keystore: Keystore to encrypt data if encrypt is True + :param encrypt: Re-encrypt/re-compress the data, if True + :param keystore: Keystore to encrypt the data, if encrypt is True :return: True if the upload succeeded, False otherwise """ @@ -157,12 +201,19 @@ async def upload_user_content( return await self._parent.low_level.upload_object(object_type, object_id, object_meta, object_data) - async def upload_user_contents(self, datas: Iterable[Dict[str, Any]]) -> List[Tuple[str, Optional[NovelAIError]]]: + async def upload_user_contents( + self, + datas: Iterable[Dict[str, Any]], + encrypt: bool = False, + keystore: Optional[Keystore] = None, + ) -> List[Tuple[str, Optional[NovelAIError]]]: """ Upload multiple user contents. If the content has been decrypted with decrypt_user_data, it should be re-encrypted with encrypt_user_data, even if the decryption failed :param datas: Objects to upload + :param encrypt: Re-encrypt/re-compress the data, if True + :param keystore: Keystore to encrypt the data, if encrypt is True :return: A list of (id, error) of all the objects that failed to be uploaded """ @@ -171,7 +222,8 @@ async def upload_user_contents(self, datas: Iterable[Dict[str, Any]]) -> List[Tu for data in datas: try: - success = await self.upload_user_content(data) + success = await self.upload_user_content(data, encrypt, keystore) + if not success: status.append((data["id"], None)) except NovelAIError as e: @@ -192,7 +244,7 @@ async def _generate( **kwargs, ): """ - Generate content from an AI on the NovelAI server, with streaming support + Generate text with streaming support :param prompt: Context to give to the AI (raw text or list of tokens) :param model: Model to use for the AI @@ -202,14 +254,14 @@ async def _generate( :param biases: Tokens to bias (up or down) for this generation :param prefix: Module to use for this generation :param stream: Use data streaming for the response - :param kwargs: Additional parameters to pass to the requests + :param kwargs: Additional parameters to pass to the requests. Can also be used to overwrite existing parameters :return: Content that has been generated """ if preset is None: raise ValueError("Uninitialized preset") - if preset.model != model: + if preset.model is not model: raise ValueError(f"Preset '{preset.name}' (model {preset.model}) is not compatible with model {model}") preset_params = preset.to_settings() @@ -223,37 +275,23 @@ async def _generate( params["prefix"] = "vanilla" if prefix is None else prefix - if params["num_logprobs"] == GlobalSettings.NO_LOGPROBS: - del params["num_logprobs"] + for k, v, c in (("bad_words_ids", bad_words, BanList), ("logit_bias_exp", biases, BiasGroup)): + k: str + v: Union[Iterable[BanList], Iterable[BiasGroup], BanList, BiasGroup, None] + c: Union[Type[BanList], Type[BiasGroup]] - if bad_words is not None: - if isinstance(bad_words, BanList): - bad_words = [bad_words] + if v is not None: + if isinstance(v, c): + v = [v] - for i, bad_word in enumerate(bad_words): - if not isinstance(bad_word, BanList): - raise ValueError( - f"Expected type 'BanList' for item #{i} of 'bad_words', " f"but got '{type(bad_word)}'" - ) + for i, obj in enumerate(v): + if not isinstance(obj, c): + raise ValueError(f"Expected type '{c}' for item #{i} of '{k}', but got '{type(obj)}'") - params["bad_words_ids"].extend(bad_word.get_tokenized_banlist(model)) + params[k].extend(obj.get_tokenized_banlist(model)) - if biases is not None: - if isinstance(biases, BiasGroup): - biases = [biases] - - for i, bias in enumerate(biases): - if not isinstance(bias, BiasGroup): - raise ValueError(f"Expected type 'BiasGroup' for item #{i} of 'biases', but got '{type(bias)}'") - - params["logit_bias_exp"].extend(bias.get_tokenized_biases(model)) - - # Delete the options that return an unknown error (success status code, but server error) - if "repetition_penalty_slope" in params and params["repetition_penalty_slope"] == 0: - del params["repetition_penalty_slope"] - - if not params["bad_words_ids"]: - del params["bad_words_ids"] + if k in params and not params[k]: + del params[k] async for i in self._parent.low_level.generate(prompt, model, params, stream): yield i @@ -270,7 +308,7 @@ async def generate( **kwargs, ) -> Dict[str, Any]: """ - Generate text from an AI on the NovelAI server. The text is returned at once, when generation is finished. + Generate text. The text is returned at once, when generation is finished. :param prompt: Context to give to the AI (raw text or list of tokens) :param model: Model to use for the AI @@ -309,7 +347,7 @@ async def generate_stream( **kwargs, ) -> AsyncIterable[Dict[str, Any]]: """ - Generate text from an AI on the NovelAI server. The text is returned one token at a time, as it is generated. + Generate text. The text is returned one token at a time, as it is generated. :param prompt: Context to give to the AI (raw text or list of tokens) :param model: Model to use for the AI @@ -345,7 +383,7 @@ async def generate_image( **kwargs, ) -> AsyncIterable[Union[str, bytes]]: """ - Generate image from an AI on the NovelAI server + Generate one or multiple image(s) :param prompt: Prompt to give to the AI (raw text describing the wanted image) :param model: Model to use for the AI @@ -353,7 +391,7 @@ async def generate_image( :param action: Type of image generation to use :param kwargs: Additional parameters to pass to the requests. Can also be used to overwrite existing parameters - :return: Content that has been generated + :return: Pair(s) (name, image) that have been generated """ settings = preset.to_settings(model) diff --git a/novelai_api/_low_level.py b/novelai_api/_low_level.py index e59d5e2..4d530dc 100644 --- a/novelai_api/_low_level.py +++ b/novelai_api/_low_level.py @@ -119,6 +119,12 @@ async def _parse_sse_stream(rsp: ClientResponse) -> AsyncIterator[Dict[str, str] @classmethod async def _parse_response(cls, rsp: ClientResponse): + """ + Parse the content of a ClientResponse depending on the content-type + + :param rsp: ClientResponse returned by a request + """ + content_type = rsp.content_type if content_type == "application/json": @@ -247,6 +253,14 @@ async def login(self, access_key: str) -> Dict[str, str]: async def change_access_key( self, current_key: str, new_key: str, new_email: Optional[str] = None ) -> Dict[str, str]: + """ + Change the access key of the given account + + :param current_key: Current key of the account + :param new_key: New key of the account + :param new_email: New email, if it changed + """ + assert_type(str, current_key=current_key, new_key=new_key) assert_type((str, NoneType), new_email=new_email) assert_len(64, current_key=current_key, new_key=new_key) @@ -265,12 +279,24 @@ async def change_access_key( return content async def send_email_verification(self, email: str) -> bool: + """ + Send the email for account verification + + :param email: Address to send the email to + """ + assert_type(str, email=email) async for rsp, content in self.request("post", "/user/resend-email-verification", {"email": email}): return self._treat_response_bool(rsp, content, 200) async def verify_email(self, verification_token: str) -> bool: + """ + Check the token sent for email verification + + :param verification_token: Token sent to the email address + """ + assert_type(str, verification_token=verification_token) assert_len(64, verification_token=verification_token) @@ -278,6 +304,10 @@ async def verify_email(self, verification_token: str) -> bool: return self._treat_response_bool(rsp, content, 200) async def get_information(self) -> Dict[str, Any]: + """ + Get extensive information about the account + """ + async for rsp, content in self.request("get", "/user/information"): self._treat_response_object(rsp, content, 200) @@ -287,12 +317,27 @@ async def get_information(self) -> Dict[str, Any]: return content async def request_account_recovery(self, email: str) -> bool: + """ + Send a recovery token to the provided email address, if the account has been lost + **WARNING**: the content will not be readable with a different encryption key + + :param email: Address to send the email to + """ + assert_type(str, email=email) async for rsp, content in self.request("post", "/user/recovery/request", {"email": email}): return self._treat_response_bool(rsp, content, 202) async def recover_account(self, recovery_token: str, new_key: str, delete_content: bool = False) -> Dict[str, Any]: + """ + Recover the lost account + + :param recovery_token: Token sent to the given email address + :param new_key: New access key for the account + :param delete_content: Delete all content that was on the account + """ + assert_type(str, recovery_token=recovery_token, new_key=new_key) assert_type(bool, delete_content=delete_content) assert_len(16, operator.ge, recovery_token=recovery_token) @@ -313,19 +358,34 @@ async def recover_account(self, recovery_token: str, new_key: str, delete_conten return content async def delete_account(self) -> bool: + """ + Delete the account + """ + async for rsp, content in self.request("post", "/user/delete", None): return self._treat_response_bool(rsp, content, 200) async def get_data(self) -> Dict[str, Any]: + """ + Get various data about the account + """ + async for rsp, content in self.request("get", "/user/data"): self._treat_response_object(rsp, content, 200) if self.is_schema_validation_enabled: + # FIXME: doesn't seem right SchemaValidator.validate("schema_AccountInformationResponse", content) return content async def get_priority(self) -> Dict[str, Any]: + """ + Get the priority information of the account + + The priority system is a legacy system and isn't really important + """ + async for rsp, content in self.request("get", "/user/priority"): self._treat_response_object(rsp, content, 200) @@ -335,12 +395,22 @@ async def get_priority(self) -> Dict[str, Any]: return content class SubscriptionTier(enum.IntEnum): + """ + Index of the subscription tiers + + PAPER tier is the free trial + """ + PAPER = 0 TABLET = 1 SCROLL = 2 OPUS = 3 async def get_subscription(self) -> Dict[str, Any]: + """ + Get various information about the account's subscription + """ + async for rsp, content in self.request("get", "/user/subscription"): self._treat_response_object(rsp, content, 200) @@ -350,6 +420,13 @@ async def get_subscription(self) -> Dict[str, Any]: return content async def get_keystore(self) -> Dict[str, str]: + """ + Get the keystore + + The keystore is the storage for the encryption keys of any content on the account. + Losing it is equal to losing all your encrypted content + """ + async for rsp, content in self.request("get", "/user/keystore"): self._treat_response_object(rsp, content, 200) @@ -359,12 +436,25 @@ async def get_keystore(self) -> Dict[str, str]: return content async def set_keystore(self, keystore: Dict[str, str]) -> bool: + """ + Set the keystore + + The keystore is the storage for the encryption keys of any content on the account. + Losing it (or overwriting it with wrong data) is equal to losing all your encrypted content + """ + assert_type(dict, keystore=keystore) async for rsp, content in self.request("put", "/user/keystore", keystore): return self._treat_response_object(rsp, content, 200) async def download_objects(self, object_type: str) -> Dict[str, List[Dict[str, Union[str, int]]]]: + """ + Download all the objects of a given type from the account + + :param object_type: Type of the objects to download + """ + assert_type(str, object_type=object_type) async for rsp, content in self.request("get", f"/user/objects/{object_type}"): @@ -376,6 +466,14 @@ async def download_objects(self, object_type: str) -> Dict[str, List[Dict[str, U return content async def upload_objects(self, object_type: str, meta: str, data: str) -> bool: + """ + Upload multiple objects of the given type + + :param object_type: Type of the objects to upload + :param meta: Meta of the objects to upload (meta links to encryption key in keystore) + :param data: Serialized data of the content to upload + """ + assert_type(str, object_type=object_type, meta=meta, data=data) assert_len(128, operator.le, meta=meta) @@ -385,6 +483,13 @@ async def upload_objects(self, object_type: str, meta: str, data: str) -> bool: return content async def download_object(self, object_type: str, object_id: str) -> Dict[str, Union[str, int]]: + """ + Download the selected object of a given type from the account + + :param object_type: Type of the object to download + :param object_id: Id of the selected object + """ + assert_type(str, object_type=object_type, object_id=object_id) async for rsp, content in self.request("get", f"/user/objects/{object_type}/{object_id}"): @@ -396,6 +501,15 @@ async def download_object(self, object_type: str, object_id: str) -> Dict[str, U return content async def upload_object(self, object_type: str, object_id: str, meta: str, data: str) -> bool: + """ + Upload an object of then given type + + :param object_type: Type of the object to upload + :param meta: Meta of the object to upload (meta links to encryption key in keystore) + :param data: Serialized data of the content to upload + :param object_id: Id of the selected object + """ + assert_type(str, object_type=object_type, object_id=object_id, meta=meta, data=data) assert_len(128, operator.le, meta=meta) @@ -406,22 +520,41 @@ async def upload_object(self, object_type: str, object_id: str, meta: str, data: return content async def delete_object(self, object_type: str, object_id: str) -> Dict[str, Union[str, int]]: + """ + Download the selected object of a given type from the account + + :param object_type: Type of the object to delete + :param object_id: Id of the selected object + """ + assert_type(str, object_type=object_type, object_id=object_id) async for rsp, content in self.request("delete", f"/user/objects/{object_type}/{object_id}"): return self._treat_response_object(rsp, content, 200) async def get_settings(self) -> str: + """ + Get the account settings. The format is arbitrary. + """ + async for rsp, content in self.request("get", "/user/clientsettings"): return self._treat_response_object(rsp, content, 200) async def set_settings(self, value: str) -> bool: + """ + Set the account settings. The format is arbitrary. + """ + assert_type(str, value=value) async for rsp, content in self.request("put", "/user/clientsettings", value): return self._treat_response_bool(rsp, content, 200) async def bind_subscription(self, payment_processor: str, subscription_id: str) -> bool: + """ + Bind payment information to the account to renew subscription monthly + """ + assert_type(str, payment_processor=payment_processor, subscription_id=subscription_id) data = {"paymentProcessor": payment_processor, "subscriptionId": subscription_id} @@ -430,6 +563,10 @@ async def bind_subscription(self, payment_processor: str, subscription_id: str) return self._treat_response_bool(rsp, content, 201) async def change_subscription(self, new_plan: str) -> bool: + """ + Change the subscription tier. Payment information should still be bound to the account + """ + assert_type(str, new_plan=new_plan) async for rsp, content in self.request("post", "/user/subscription/change", {"newSubscriptionPlan": new_plan}): @@ -437,6 +574,8 @@ async def change_subscription(self, new_plan: str) -> bool: async def generate(self, prompt: Union[List[int], str], model: Model, params: Dict[str, Any], stream: bool = False): """ + Generate text with streaming support + :param prompt: Input to be sent the AI :param model: Model of the AI :param params: Generation parameters @@ -464,10 +603,16 @@ async def generate(self, prompt: Union[List[int], str], model: Model, params: Di yield content async def classify(self) -> NoReturn: + """ + Not implemented + """ + raise NotImplementedError("Function is not implemented yet") async def train_module(self, data: str, rate: int, steps: int, name: str, desc: str) -> Dict[str, Any]: """ + Train a module for text gen + :param data: Dataset of the module, in one single string :param rate: Learning rate of the training :param steps: Number of steps to train the module for @@ -497,7 +642,7 @@ async def train_module(self, data: str, rate: int, steps: int, name: str, desc: async def get_trained_modules(self) -> List[Dict[str, Any]]: """ - :return: List of modules trained or in training + Get the modules currently in training or that finished training """ async for rsp, content in self.request("get", "/ai/module/all"): @@ -510,9 +655,9 @@ async def get_trained_modules(self) -> List[Dict[str, Any]]: async def get_trained_module(self, module_id: str) -> Dict[str, Any]: """ - :param module_id: Id of the module + Get a module currently in training or that finished training - :return: Selected module, trained or in training + :param module_id: Id of the selected module """ assert_type(str, module_id=module_id) @@ -527,9 +672,9 @@ async def get_trained_module(self, module_id: str) -> Dict[str, Any]: async def delete_module(self, module_id: str) -> Dict[str, Any]: """ - Delete the module with id :ref: `module_id` + Delete a module currently in training or that finished training - :param module_id: Id of the module + :param module_id: Id of the selected module :return: Module that got deleted """ @@ -545,13 +690,13 @@ async def delete_module(self, module_id: str) -> Dict[str, Any]: async def generate_voice(self, text: str, seed: str, voice: int, opus: bool, version: str) -> Dict[str, Any]: """ - Generate the Text-to-Speech of :ref: `text` using the given seed and voice + Generate the Text To Speech of the given text :param text: Text to synthesize into voice (text will be cut to 1000 characters backend-side) - :param seed: Person to use the voice of + :param seed: Voice to use :param voice: Index of the voice to use :param opus: True for WebM format, False for mp3 format - :param version: Version of the TTS + :param version: Version of the TTS ("v1" or "v2") :return: TTS audio data of the text """ @@ -636,7 +781,7 @@ async def generate_image( async def generate_controlnet_mask(self, model: ControlNetModel, image: str) -> Tuple[str, bytes]: """ - Get the ControlNet's mask for the image. Used for ImageSampler["controlnet_condition"] + Get the ControlNet's mask for the given image. Used for ImageSampler.controlnet_condition :param model: ControlNet model to use :param image: b64 encoded PNG image to get the mask of @@ -656,7 +801,7 @@ async def generate_controlnet_mask(self, model: ControlNetModel, image: str) -> async def upscale_image(self, image: str, width: int, height: int, scale: int) -> Tuple[str, bytes]: """ - Upscale the image. Afaik, the only allowed values for scale are 2 and 4. + Upscale the given image. Afaik, the only allowed values for scale are 2 and 4. :param image: b64 encoded PNG image to upscale :param width: Width of the starting image diff --git a/tests/api/test_generate.py b/tests/api/test_generate.py deleted file mode 100644 index 41a9f40..0000000 --- a/tests/api/test_generate.py +++ /dev/null @@ -1,398 +0,0 @@ -import asyncio -from logging import Logger, StreamHandler -from os import environ as env -from typing import Tuple - -import pytest -from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession - -from novelai_api import NovelAIAPI, NovelAIError -from novelai_api.BanList import BanList -from novelai_api.BiasGroup import BiasGroup -from novelai_api.GlobalSettings import GlobalSettings -from novelai_api.Preset import Model, Preset -from novelai_api.Tokenizer import Tokenizer -from novelai_api.utils import b64_to_tokens - -# Text generation length -GENERATION_LENGTH = 10 - -WORKERS = int(env["PYTEST_XDIST_WORKER_COUNT"]) if "PYTEST_XDIST_WORKER_COUNT" in env else 1 -PROXY = env["NAI_PROXY"] if "NAI_PROXY" in env else None - -# Minimum time for a test (in seconds) -MIN_TEST_TIME = 2 * WORKERS - - -async def run_test(func, *args, is_async: bool, attempts: int = 5): - async def test_func(): - if is_async: - try: - async with ClientSession() as test_session: - api = NovelAIAPI(test_session) - api.proxy = PROXY - - return await func(api, *args) - except Exception as test_exc: - await test_session.close() - raise test_exc - - else: - api = NovelAIAPI() - api.proxy = PROXY - - return await func(api, *args) - - err: Exception = RuntimeError("Unknown error") - for _ in range(attempts): - try: - # inject api and execute the test - return await asyncio.gather(test_func(), asyncio.sleep(MIN_TEST_TIME)) - - except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e: - err = e - retry = True - - except NovelAIError as e: - err = e - retry = any( - [ - e.status == 502, # Bad Gateway - e.status == 520, # Cloudflare Unknown Error - e.status == 524, # Cloudflare Gateway Error - ] - ) - - if not retry: - break - - await asyncio.sleep(10) - - # ping every 5 mins until connection is re-established - async with ClientSession() as session: - while True: - try: - rsp = await session.get("https://www.google.com", timeout=5 * 60) - rsp.raise_for_status() - - break - except ClientConnectionError: - await asyncio.sleep(5 * 60) - except asyncio.TimeoutError: - pass - - raise err - - -def permutations(*args): - args = [list(a) for a in args if len(a)] - l = len(args) - ilist = [0] * l - - while True: - yield [arg[i] for arg, i in zip(args, ilist)] - - ilist[0] += 1 - for i in range(l): - if ilist[i] == len(args[i]): - if i + 1 == l: # end, don't overflow - return - - ilist[i + 1] += 1 - ilist[i] = 0 - else: - break - - -if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env: - raise RuntimeError("Please ensure that NAI_USERNAME and NAI_PASSWORD are set in your environment") - -username = env["NAI_USERNAME"] -password = env["NAI_PASSWORD"] - -logger = Logger("NovelAI") -logger.addHandler(StreamHandler()) - -input_txt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam at dolor dictum, interdum est sed, consequat arcu. Pellentesque in massa eget lorem fermentum placerat in pellentesque purus. Suspendisse potenti. Integer interdum, felis quis porttitor volutpat, est mi rutrum massa, venenatis viverra neque lectus semper metus. Pellentesque in neque arcu. Ut at arcu blandit purus aliquet finibus. Suspendisse laoreet risus a gravida semper. Aenean scelerisque et sem vitae feugiat. Quisque et interdum diam, eu vehicula felis. Ut tempus quam eros, et sollicitudin ligula auctor at. Integer at tempus dui, quis pharetra purus. Duis venenatis tincidunt tellus nec efficitur. Nam at malesuada ligula." # noqa: E501 # pylint: disable=C0301 -prompts = [input_txt] -tokenize_prompt = [False, True] - -models = [*Model] -# NOTE: uncomment that if you're not Opus -# models.remove(Model.Genji) -# models.remove(Model.Snek) - -models_presets = [(model, preset) for model in models for preset in Preset[model]] - -model_input_permutation = [*permutations(models, prompts, tokenize_prompt)] -model_preset_input_permutation = [*permutations(models_presets, prompts, tokenize_prompt)] - - -async def simple_generate(api: NovelAIAPI, model: Model, preset: Preset, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - gen = await api.high_level.generate(prompt, model, preset, global_settings) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_simple_generate_sync(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # sync handler - await run_test(simple_generate, *model_preset, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_simple_generate_async(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # async handler - await run_test(simple_generate, *model_preset, prompt, tokenize, is_async=True) - - -async def default_generate(api: NovelAIAPI, model: Model, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - - preset = Preset.from_default(model) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - gen = await api.high_level.generate(prompt, model, preset, global_settings) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model,prompt,tokenize", model_input_permutation) -async def test_default_generate_sync(model: Model, prompt: str, tokenize: bool): - # sync handler - await run_test(default_generate, model, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model,prompt,tokenize", model_input_permutation) -async def test_default_generate_async(model: Model, prompt: str, tokenize: bool): - # async handler - await run_test(default_generate, model, prompt, tokenize, is_async=True) - - -async def official_generate(api: NovelAIAPI, model: Model, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - - preset = Preset.from_official(model) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - gen = await api.high_level.generate(prompt, model, preset, global_settings) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model,prompt,tokenize", model_input_permutation) -async def test_official_generate_sync(model: Model, prompt: str, tokenize: bool): - # sync handler - await run_test(official_generate, model, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model,prompt,tokenize", model_input_permutation) -async def test_official_generate_async(model: Model, prompt: str, tokenize: bool): - # async handler - await run_test(official_generate, model, prompt, tokenize, is_async=True) - - -async def globalsettings_generate(api: NovelAIAPI, model: Model, preset: Preset, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model {model.value}, preset {preset.name}\n") - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings( - bias_dinkus_asterism=True, - ban_brackets=True, - num_logprobs=GlobalSettings.NO_LOGPROBS, - ) - - gen = await api.high_level.generate(prompt, model, preset, global_settings) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_globalsettings_generate_sync(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # sync handler - await run_test(globalsettings_generate, *model_preset, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_globalsettings_generate_async(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # async handler - await run_test(globalsettings_generate, *model_preset, prompt, tokenize, is_async=True) - - -async def bias_generate(api: NovelAIAPI, model: Model, preset: Preset, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - global_settings["bias_dinkus_asterism"] = True - global_settings["ban_brackets"] = True - global_settings["num_logprobs"] = 1 - - bias1 = ( - BiasGroup(-0.1) - .add("It is", " It is", "It was", " It was", Tokenizer.encode(model, "There is")) - .add(Tokenizer.encode(model, "There are")) - ) - bias1 += " as it is" - - bias2 = BiasGroup(0.1).add(" because", " since").add(" why").add(" when", " about") - bias2 += "as it is" - - gen = await api.high_level.generate(prompt, model, preset, global_settings, biases=(bias1, bias2)) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_bias_generate_sync(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # sync handler - await run_test(bias_generate, *model_preset, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_bias_generate_async(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # async handler - await run_test(bias_generate, *model_preset, prompt, tokenize, is_async=True) - - -async def ban_generate(api: NovelAIAPI, model: Model, preset: Preset, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - global_settings["bias_dinkus_asterism"] = True - global_settings["ban_brackets"] = True - global_settings["num_logprobs"] = 1 - - banned = BanList().add("***", "---", Tokenizer.encode(model, "///")).add("fairly") - banned += "commonly" - banned += " commonly" - - gen = await api.high_level.generate(prompt, model, preset, global_settings, bad_words=banned) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_ban_generate_sync(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # sync handler - await run_test(ban_generate, *model_preset, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_ban_generate_async(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # async handler - await run_test(ban_generate, *model_preset, prompt, tokenize, is_async=True) - - -async def ban_and_bias_generate(api: NovelAIAPI, model: Model, preset: Preset, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - global_settings["bias_dinkus_asterism"] = True - global_settings["ban_brackets"] = True - global_settings["num_logprobs"] = 1 - - banned = BanList().add("***", "---", Tokenizer.encode(model, "///")).add("fairly") - banned += "commonly" - banned += " commonly" - - bias2 = BiasGroup(0.1).add(" because", " since").add(" why").add(" when", " about") - bias2 += "as it is" - - gen = await api.high_level.generate(prompt, model, preset, global_settings, bad_words=banned, biases=[bias2]) - logger.info(gen) - logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_ban_and_bias_generate_sync(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # sync handler - await run_test(ban_and_bias_generate, *model_preset, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_ban_and_bias_generate_async(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # async handler - await run_test(ban_and_bias_generate, *model_preset, prompt, tokenize, is_async=True) - - -async def ban_and_bias_generate_streaming(api: NovelAIAPI, model: Model, preset: Preset, prompt: str, tokenize: bool): - await api.high_level.login(username, password) - preset["max_length"] = GENERATION_LENGTH - - logger.info("Using model %s, preset %s\n", model.value, preset.name) - - if tokenize: - prompt = Tokenizer.encode(model, prompt) - - global_settings = GlobalSettings() - global_settings["bias_dinkus_asterism"] = True - global_settings["ban_brackets"] = True - global_settings["num_logprobs"] = 1 - - banned = BanList().add("***", "---", Tokenizer.encode(model, "///")).add("fairly") - banned += "commonly" - banned += " commonly" - - bias2 = BiasGroup(0.1).add(" because", " since").add(" why").add(" when", " about") - bias2 += "as it is" - - async for i in api.high_level.generate_stream( - prompt, model, preset, global_settings, bad_words=banned, biases=[bias2] - ): - logger.info(i) - logger.info(Tokenizer.decode(model, b64_to_tokens(i["token"]))) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_ban_and_bias_generate_streaming_sync(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # sync handler - await run_test(ban_and_bias_generate_streaming, *model_preset, prompt, tokenize, is_async=False) - - -@pytest.mark.parametrize("model_preset,prompt,tokenize", model_preset_input_permutation) -async def test_ban_and_bias_generate_streaming_async(model_preset: Tuple[Model, Preset], prompt: str, tokenize: bool): - # async handler - await run_test(ban_and_bias_generate_streaming, *model_preset, prompt, tokenize, is_async=True) diff --git a/tests/api/test_generate_parallel.py b/tests/api/test_generate_parallel.py deleted file mode 100644 index 40c1425..0000000 --- a/tests/api/test_generate_parallel.py +++ /dev/null @@ -1,75 +0,0 @@ -# Test to ensure the wrapper works with parallelism, do not spam the API ! - -from asyncio import gather, run -from logging import Logger, StreamHandler -from os import environ as env - -import pytest -from aiohttp import ClientSession - -from novelai_api import NovelAIAPI -from novelai_api.GlobalSettings import GlobalSettings -from novelai_api.Preset import Model, Preset -from novelai_api.Tokenizer import Tokenizer -from novelai_api.utils import b64_to_tokens - -models = [Model.Sigurd] - -if "NAI_USERNAME" not in env or "NAI_PASSWORD" not in env: - raise RuntimeError("Please ensure that NAI_USERNAME and NAI_PASSWORD are set in your environment") - -username = env["NAI_USERNAME"] -password = env["NAI_PASSWORD"] -PROXY = env["NAI_PROXY"] if "NAI_PROXY" in env else None - -logger = Logger("NovelAI") -logger.addHandler(StreamHandler()) - - -async def generate_5(api: NovelAIAPI, model: Model): - api.proxy = PROXY - await api.high_level.login(username, password) - - preset = Preset.from_default(model) - global_settings = GlobalSettings(ban_brackets=True, bias_dinkus_asterism=True) - - logger.info("Using model %s\n", model.value) - - input_txt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam at dolor dictum, interdum est sed, consequat arcu. Pellentesque in massa eget lorem fermentum placerat in pellentesque purus. Suspendisse potenti. Integer interdum, felis quis porttitor volutpat, est mi rutrum massa, venenatis viverra neque lectus semper metus. Pellentesque in neque arcu. Ut at arcu blandit purus aliquet finibus. Suspendisse laoreet risus a gravida semper. Aenean scelerisque et sem vitae feugiat. Quisque et interdum diam, eu vehicula felis. Ut tempus quam eros, et sollicitudin ligula auctor at. Integer at tempus dui, quis pharetra purus. Duis venenatis tincidunt tellus nec efficitur. Nam at malesuada ligula." # noqa: E501 # pylint: disable=C0301 - prompt = Tokenizer.encode(model, input_txt) - - preset["max_length"] = 20 - gens = [api.high_level.generate(prompt, model, preset, global_settings) for _ in range(5)] - results = await gather(*gens) - for i, gen in enumerate(results): - logger.info("Gen %s:", i) - logger.info("\t%s", Tokenizer.decode(model, b64_to_tokens(gen["output"]))) - logger.info("") - - -@pytest.mark.parametrize("model", models) -async def test_run_5_generate_sync(model: Model): - # sync handler - api = NovelAIAPI() - await generate_5(api, model) - - -@pytest.mark.parametrize("model", models) -async def test_run_5_generate_async(model: Model): - # async handler - try: - async with ClientSession() as session: - api = NovelAIAPI(session) - await generate_5(api, model) - except Exception as e: - await session.close() - raise e - - -if __name__ == "__main__": - - async def main(): - await test_run_5_generate_sync(Model.Sigurd) - await test_run_5_generate_async(Model.Sigurd) - - run(main())