Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests refactor #19

Merged
merged 6 commits into from
Apr 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ jobs:
python-version: [3.7, 3.8, 3.9, "3.10", 3.11]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
pip install nox
Expand All @@ -40,11 +41,12 @@ jobs:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
pip install nox
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ Download via [pip](https://pypi.org/project/novelai-api):
pip install novelai-api
```

A full list of examples is available in the [example](/example) directory
A full list of examples is available in the [example](example) directory

The API works through the NovelAIAPI object.
It is split in 2 groups: NovelAIAPI.low_level and NovelAIAPI.high_level

## low_level
The low level interface is a strict implementation of the official API (<https://api.novelai.net/docs>).
It only checks for input types via assert and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True
It only checks for input types via assert, and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True

## high_level
The high level interface builds on the low level one for easier handling of complex settings.
Expand Down
7 changes: 0 additions & 7 deletions docs/requirements.txt

This file was deleted.

33 changes: 30 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import os
import sys
from pathlib import Path
from types import ModuleType
from typing import List
from types import FunctionType
from typing import List, Union

from sphinx.application import Sphinx
from sphinx.ext.autodoc import Options
Expand All @@ -32,6 +32,7 @@

extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"sphinx.ext.extlinks",
"sphinx.ext.viewcode",
"myst_parser",
Expand All @@ -40,6 +41,8 @@
"hoverxref.extension",
]

add_module_names = False

autodoc_class_signature = "separated"
autodoc_member_order = "bysource"
autodoc_typehints_format = "fully-qualified"
Expand Down Expand Up @@ -81,7 +84,11 @@
# -- Hooks -------------------------------------------------------------------


def format_docstring(_app: Sphinx, what: str, name: str, obj: ModuleType, _options: Options, lines: List[str]):
def format_docstring(_app: Sphinx, what: str, name: str, obj, _options: Options, lines: List[str]):
"""
Inject metadata in docstrings if necessary
"""

kwargs = {
"obj_type": what,
"obj_name": name,
Expand All @@ -99,5 +106,25 @@ def format_docstring(_app: Sphinx, what: str, name: str, obj: ModuleType, _optio
lines[i] = line.format(**kwargs)


def hide_test_signature(
_app: Sphinx,
what: str,
name: str,
_obj: FunctionType,
_options: Options,
signature: str,
return_annotation: Union[str, None],
):
if what == "function":
module_name, *_, file_name, _func_name = name.split(".")

# erase signature for functions from test files
if module_name == "tests" and file_name.startswith("test_"):
return "", None

return signature, return_annotation


def setup(app):
app.connect("autodoc-process-docstring", format_docstring)
app.connect("autodoc-process-signature", hide_test_signature)
9 changes: 9 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ example
:maxdepth: 2

example/example


API
---

.. toctree::
:maxdepth: 2

tests/api/api
7 changes: 7 additions & 0 deletions docs/source/tests/api/api.boilerplate.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
boilerplate
===========

.. automodule:: tests.api.boilerplate
:members:
:undoc-members:
:show-inheritance:
47 changes: 47 additions & 0 deletions docs/source/tests/api/api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
API directory
=============

Requirements
------------


Usage
-----


Content
-------

test_decrypt_encrypt_integrity_check.py
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: tests.api.test_decrypt_encrypt_integrity_check
:members:

test_imagegen_samplers.py
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: tests.api.test_imagegen_samplers
:members:

test_sync_gen.py
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: tests.api.test_sync_gen
:members:

test_textgen_presets.py
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: tests.api.test_textgen_presets
:members:

test_textgen_sanity.py
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: tests.api.test_textgen_sanity
:members:


Reference
---------

.. toctree::
:maxdepth: 2

api.boilerplate
48 changes: 45 additions & 3 deletions novelai_api/Preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ class Order(IntEnum):
}


def enum_contains(enum_class: EnumMeta, value) -> bool:
def enum_contains(enum_class: EnumMeta, value: str) -> bool:
"""
Check if the value provided is valid for the enum

:param enum_class: Class of the Enum
:param value: Value to check
"""

if not hasattr(enum_class, "enum_member_values"):
enum_class.enum_member_values = list(e.value for e in enum_class)

Expand All @@ -45,6 +52,33 @@ def enum_contains(enum_class: EnumMeta, value) -> bool:
return value in values


def _strip_model_version(value: str) -> str:
parts = value.split("-")

if parts[-1].startswith("v") and parts[-1][1:].isdecimal():
parts = parts[:-1]

return "-".join(parts)


def collapse_model(enum_class: EnumMeta, value: str):
"""
Collapse multiple version of a model to the last model value

:param enum_class: Class of the Enum
:param value: Value of the model to collapse
"""

if not hasattr(enum_class, "enum_member_values"):
enum_class.enum_member_values = {_strip_model_version(e.value): e for e in enum_class}

values = enum_class.enum_member_values
if len(values) == 0:
raise ValueError(f"Empty enum class: '{enum_class}'")

return values.get(_strip_model_version(value))


class StrEnum(str, Enum):
pass

Expand Down Expand Up @@ -281,11 +315,19 @@ def to_settings(self) -> Dict[str, Any]:
if "textGenerationSettingsVersion" in settings:
del settings["textGenerationSettingsVersion"] # not API relevant

# remove disabled sampling options
for i, o in enumerate(Order):
if not self._enabled[i]:
settings["order"].remove(o)
settings.pop(ORDER_TO_NAME[o], None)

settings["order"] = [e.value for e in settings["order"]]

# seems that 0 doesn't disable it, but does weird things
if settings.get("repetition_penalty_range", None) == 0:
del settings["repetition_penalty_range"]

# Delete the options that return an unknown error (success status code, but server error)
# delete the options that return an unknown error (success status code, but server error)
if settings.get("repetition_penalty_slope", None) == 0:
del settings["repetition_penalty_slope"]

Expand Down Expand Up @@ -345,7 +387,7 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P

# FIXME: collapse model version
model_name = data["model"] if "model" in data else ""
model = Model(model_name) if enum_contains(Model, model_name) else None
model = collapse_model(Model, model_name)

settings = data["parameters"] if "parameters" in data else {}

Expand Down
6 changes: 6 additions & 0 deletions novelai_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
"""
:class:`NovelAI_API`

:class:`NovelAIError`
"""

from novelai_api.NovelAI_API import NovelAIAPI
from novelai_api.NovelAIError import NovelAIError
5 changes: 5 additions & 0 deletions novelai_api/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ async def _generate(
params.update(global_params)
params.update(kwargs)

# adjust repetition penalty value for Sigurd and Euterpe
if model in (Model.Sigurd, Model.Euterpe) and "repetition_penalty" in params:
rep_pen = params["repetition_penalty"]
params["repetition_penalty"] = (0.525 * (rep_pen - 1) / 7) + 1

params["prefix"] = "vanilla" if prefix is None else prefix

for k, v, c in (("bad_words_ids", bad_words, BanList), ("logit_bias_exp", biases, BiasGroup)):
Expand Down
13 changes: 7 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_dotenv(session: nox.Session):
return json.loads(dotenv_str)


def install_package(session: nox.Session, *packages: str, dev: bool = False):
def install_package(session: nox.Session, *packages: str, dev: bool = False, docs: bool = False):
session.install("poetry")
session.install("python-dotenv")

Expand All @@ -40,6 +40,8 @@ def install_package(session: nox.Session, *packages: str, dev: bool = False):
poetry_groups = []
if dev:
poetry_groups.extend(["--with", "dev"])
if docs:
poetry_groups.extend(["--with", "docs"])

session.run("python", "-m", "poetry", "export", "--output=requirements.txt", "--without-hashes", *poetry_groups)
session.run("python", "-m", "poetry", "build", "--format=wheel")
Expand Down Expand Up @@ -67,7 +69,7 @@ def pre_commit(session: nox.Session):
@nox.session(py=test_py_versions, name="test-mock")
def test_mock(session: nox.Session):
install_package(session, dev=True)
session.run("pytest", "--tb=short", "-n", "auto", "tests/mock/")
session.run("pytest", "tests/mock/")


@nox.session(py=test_py_versions, name="test-api")
Expand All @@ -76,9 +78,9 @@ def test_api(session: nox.Session):
session.run("npm", "install", "fflate", external=True)

if session.posargs:
session.run("pytest", "--tb=short", *(f"tests/api/{e}" for e in session.posargs))
session.run("pytest", *(f"tests/api/{e}" for e in session.posargs))
else:
session.run("pytest", "--tb=short", "tests/api/")
session.run("pytest", "tests/api/")


@nox.session()
Expand All @@ -99,8 +101,7 @@ def run(session: nox.Session):
def build_docs(session: nox.Session):
docs_path = pathlib.Path(__file__).parent / "docs"

install_package(session)
session.install("-r", str(docs_path / "requirements.txt"))
install_package(session, dev=True, docs=True)

with session.chdir(docs_path):
session.run("make", "html", external=True)
Expand Down
23 changes: 21 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ classifiers = [
]
packages = [{include = "novelai_api"}]

[tool.poetry.urls]
"Bug Tracker" = "https://github.com/Aedial/novelai-api/issues"

[tool.poetry.dependencies]
python = ">=3.7.2,<3.12"
aiohttp = {extras = ["speedups"], version = "^3.8.3"}
Expand All @@ -26,12 +29,22 @@ regex = "^2022.10.31"
sentencepiece = "^0.1.98"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
pytest-asyncio = "^0.20.1"
pytest-xdist = "^3.0.2"
pytest-randomly = "^3.12.0"
pylint = "^2.15.5"

[tool.poetry.group.docs.dependencies]
sphinx = "^5.3.0"
# patched repo to work with relative links
myst_parser = {git = "https://github.com/Aedial/MyST-Parser", rev = "adcdb9a"}
linkify-it-py = "^2.0.0"
sphinx-copybutton = "^0.5.2"
sphinx-last-updated-by-git = "^0.3.4"
sphinx-hoverxref = "^1.3.0"

[tool.flake8]
# TODO: add flake when supports come
# TODO: add flake when supports come (https://github.com/PyCQA/flake8/issues/234)

[tool.bandit]
exclude_dirs = ["tests/api/test_decrypt_encrypt_integrity_check.py"]
Expand Down Expand Up @@ -78,6 +91,12 @@ from pylint.config import find_default_config_files
path.extend(p.parent for p in find_default_config_files())
"""

[tool.pytest.ini_options]
xfail_strict = true
empty_parameter_set_mark = "fail_at_collect"
asyncio_mode = "auto"
addopts = "--tb=short -vv"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Empty file added tests/__init__.py
Empty file.
Empty file added tests/api/__init__.py
Empty file.
Loading