Skip to content

Commit

Permalink
[TEST] Fix tests
Browse files Browse the repository at this point in the history
Added error handling
Cleaned things
Added tests and descriptions
/\!\ Sanity still fails /\!\
  • Loading branch information
Aedial committed Apr 27, 2023
1 parent 9b994a6 commit 0522e9d
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 71 deletions.
Empty file added tests/api/__init__.py
Empty file.
118 changes: 65 additions & 53 deletions tests/api/boilerplate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import functools
import json
from logging import Logger, StreamHandler
from os import environ as env
from typing import Any, NoReturn
from typing import Any, Awaitable, Callable, NoReturn, Optional

import pytest
from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession
Expand Down Expand Up @@ -57,58 +58,69 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
if not self._sync:
await self._session.__aexit__(exc_type, exc_val, exc_tb)

async def run_test(self, func, *args, attempts: int = 5, wait: int = 5):
"""
Run the function ``func`` with the provided arguments and retry on error handling
The function must accept a NovelAIAPI object as first arguments
:param func: Function to run
:param args: Arguments to provide to the function
:param attempts: Number of attempts to do before raising the error
:param wait: Time (in seconds) to wait after each call
"""

err: Exception = RuntimeError("Error placeholder. Shouldn't happen")
for _ in range(attempts):
try:
res = await func(self.api, *args)
await asyncio.sleep(wait)

return res
except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e:
err = e
retry = True

except NovelAIError as e:
err = e
retry = any(
[
e.status == 502, # Bad Gateway
e.status == 520, # Cloudflare Unknown Error
e.status == 524, # Cloudflare Gateway Error
]
)

if not retry:
break

# 10s wait between each retry
await asyncio.sleep(10)

# no internet: ping every 5 mins until connection is re-established
async with ClientSession() as session:
while True:
try:
rsp = await session.get("https://www.google.com", timeout=5 * 60)
rsp.raise_for_status()

break
except ClientConnectionError:
await asyncio.sleep(5 * 60)
except asyncio.TimeoutError:
pass

raise err

def error_handler(func_ext: Optional[Callable[[Any, Any], Awaitable[Any]]] = None, *, attempts: int = 5, wait: int = 5):
"""
Add error handling to the function ``func_ext`` or ``func``
The function must accept a NovelAIAPI object as first arguments
:param func_ext: Substitute for func if the decorator is run without argument
:param attempts: Number of attempts to do before raising the error
:param wait: Time (in seconds) to wait after each call
"""

def decorator(func: Callable[[Any, Any], Awaitable[Any]]):
@functools.wraps(func)
async def wrap(*args, **kwargs):
err: Exception = RuntimeError("Error placeholder. Shouldn't happen")
for _ in range(attempts):
try:
res = await func(*args, **kwargs)
await asyncio.sleep(wait)

return res
except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e:
err = e
retry = True

except NovelAIError as e:
err = e
retry = any(
[
e.status == 502, # Bad Gateway
e.status == 520, # Cloudflare Unknown Error
e.status == 524, # Cloudflare Gateway Error
]
)

if not retry:
break

# 10s wait between each retry
await asyncio.sleep(10)

# no internet: ping every 5 mins until connection is re-established
async with ClientSession() as session:
while True:
try:
rsp = await session.get("https://www.google.com", timeout=5 * 60)
rsp.raise_for_status()

break
except ClientConnectionError:
await asyncio.sleep(5 * 60)
except asyncio.TimeoutError:
pass

raise err

return wrap

# allow to run the function without argument
if func_ext is None:
return decorator

return decorator(func_ext)


class JSONEncoder(json.JSONEncoder):
Expand Down
3 changes: 3 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def pytest_terminal_summary(terminalreporter):
terminalreporter.line("")


# TODO: add html reporting


# cannot put in boilerplate because pytest is a mess
@pytest.fixture(scope="session")
def event_loop():
Expand Down
54 changes: 54 additions & 0 deletions tests/api/test_imagegen_samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
{filename}
==============================================================================
Test which samplers currently work
"""

import itertools
from typing import Tuple

import pytest

from novelai_api import NovelAIError
from novelai_api.ImagePreset import ImageModel, ImagePreset, ImageSampler, UCPreset
from tests.api.boilerplate import api_handle, error_handler # noqa: F401 # pylint: disable=W0611

sampler_xfail = pytest.mark.xfail(True, raises=NovelAIError, reason="The sampler doesn't currently work")

models = list(ImageModel)
models.remove(ImageModel.Anime_Inpainting)

samplers = list(ImageSampler)
model_samplers = list(itertools.product(models, samplers))


@pytest.mark.parametrize(
"model_sampler",
[
pytest.param(e, marks=sampler_xfail) if e[1] in (ImageSampler.nai_smea, ImageSampler.plms) else e
for e in model_samplers
],
)
@error_handler
async def test_samplers(
api_handle, model_sampler: Tuple[ImageModel, ImagePreset] # noqa: F811 # pylint: disable=W0621
):
"""
Test the presets to ensure they work with the API
"""

api = api_handle.api
model, sampler = model_sampler

logger = api_handle.logger
logger.info(f"Testing model {model} with sampler {sampler}")

preset = ImagePreset(sampler=sampler)

# Furry doesn't have UCPreset.Preset_Low_Quality_Bad_Anatomy
if model is ImageModel.Furry:
preset.uc_preset = UCPreset.Preset_Low_Quality

async for _, _ in api.high_level.generate_image("1girl", model, preset):
pass
12 changes: 9 additions & 3 deletions tests/api/test_sync_gen.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
"""
Test if sync capabilities work without problem
This test only checks if sync works, not if the result is right. It's the job of the other tests
{filename}
==============================================================================
| Test if sync capabilities work without problem
| This test only checks if sync works, not if the result is right, it's the job of the other tests
"""

from novelai_api.GlobalSettings import GlobalSettings
from novelai_api.Preset import Model, Preset
from novelai_api.Tokenizer import Tokenizer
from novelai_api.utils import b64_to_tokens, decrypt_user_data
from tests.api.boilerplate import api_handle_sync # noqa: F401 # pylint: disable=W0611
from tests.api.boilerplate import api_handle_sync, error_handler # noqa: F401 # pylint: disable=W0611

prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam at dolor dictum, interdum est sed, consequat arcu. Pellentesque in massa eget lorem fermentum placerat in pellentesque purus. Suspendisse potenti. Integer interdum, felis quis porttitor volutpat, est mi rutrum massa, venenatis viverra neque lectus semper metus. Pellentesque in neque arcu. Ut at arcu blandit purus aliquet finibus. Suspendisse laoreet risus a gravida semper. Aenean scelerisque et sem vitae feugiat. Quisque et interdum diam, eu vehicula felis. Ut tempus quam eros, et sollicitudin ligula auctor at. Integer at tempus dui, quis pharetra purus. Duis venenatis tincidunt tellus nec efficitur. Nam at malesuada ligula." # noqa: E501 # pylint: disable=C0301
model = Model.Krake


@error_handler
async def test_is_reachable(api_handle_sync): # noqa: F811 # pylint: disable=W0621
assert await api_handle_sync.api.low_level.is_reachable() is True


@error_handler
async def test_download(api_handle_sync): # noqa: F811 # pylint: disable=W0621
key = api_handle_sync.encryption_key
keystore = await api_handle_sync.api.high_level.get_keystore(key)
modules = await api_handle_sync.api.high_level.download_user_modules()
decrypt_user_data(modules, keystore)


@error_handler
async def test_generate(api_handle_sync): # noqa: F811 # pylint: disable=W0621
api = api_handle_sync.api

Expand Down
44 changes: 33 additions & 11 deletions tests/api/test_textgen_presets.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
"""
{filename}
==============================================================================
Tests pertaining to the Preset class
"""

from typing import Tuple

import pytest

from novelai_api import NovelAIAPI
from novelai_api.GlobalSettings import GlobalSettings
from novelai_api.Preset import Model, Preset
from novelai_api.Tokenizer import Tokenizer
from novelai_api.utils import b64_to_tokens
from tests.api.boilerplate import api_handle # noqa: F401 # pylint: disable=W0611
from tests.api.boilerplate import api_handle, error_handler # noqa: F401 # pylint: disable=W0611

prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam at dolor dictum, interdum est sed, consequat arcu. Pellentesque in massa eget lorem fermentum placerat in pellentesque purus. Suspendisse potenti. Integer interdum, felis quis porttitor volutpat, est mi rutrum massa, venenatis viverra neque lectus semper metus. Pellentesque in neque arcu. Ut at arcu blandit purus aliquet finibus. Suspendisse laoreet risus a gravida semper. Aenean scelerisque et sem vitae feugiat. Quisque et interdum diam, eu vehicula felis. Ut tempus quam eros, et sollicitudin ligula auctor at. Integer at tempus dui, quis pharetra purus. Duis venenatis tincidunt tellus nec efficitur. Nam at malesuada ligula." # noqa: E501 # pylint: disable=C0301
models = [*Model]
models = list(Model)
# NOTE: uncomment that if you're not Opus
# models.remove(Model.Genji)
# models.remove(Model.Snek)

models_presets = [(model, preset) for model in models for preset in Preset[model]]
models_presets_default = [(model, Preset.from_default(model)) for model in models]


async def simple_generate(api: NovelAIAPI, model: Model, preset: Preset):
@pytest.mark.parametrize("model_preset", models_presets)
@error_handler
async def test_presets(api_handle, model_preset: Tuple[Model, Preset]): # noqa: F811 # pylint: disable=W0621
"""
Test the presets to ensure they work with the API
"""

api = api_handle.api
model, preset = model_preset

logger = api.logger
logger.info("Using model %s, preset %s\n", model.value, preset.name)

Expand All @@ -29,11 +43,19 @@ async def simple_generate(api: NovelAIAPI, model: Model, preset: Preset):
logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"])))


@pytest.mark.parametrize("model_preset", models_presets)
async def test_presets(api_handle, model_preset: Tuple[Model, Preset]): # noqa: F811 # pylint: disable=W0621
await api_handle.api.run_test(simple_generate, *model_preset)
@pytest.mark.parametrize("model", models)
async def preset_from_default(model: Model):
"""
Test the from_default constructor of Preset
"""

Preset.from_default(model)


@pytest.mark.parametrize("model", models)
async def preset_from_official(model: Model):
"""
Test the from_official constructor of Preset
"""

@pytest.mark.parametrize("model_preset", models_presets_default)
async def test_presets_default(api_handle, model_preset: Tuple[Model, Preset]): # noqa: F811 # pylint: disable=W0621
await api_handle.api.run_test(simple_generate, *model_preset)
Preset.from_official(model)
15 changes: 11 additions & 4 deletions tests/api/test_textgen_sanity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
{filename}
==============================================================================
Test if the generated content is consistent with the frontend
"""

import json
from pathlib import Path
from typing import Any, Dict, Tuple
Expand All @@ -8,7 +15,7 @@
from novelai_api.Preset import Model, Preset
from novelai_api.Tokenizer import Tokenizer
from novelai_api.utils import b64_to_tokens
from tests.api.boilerplate import api_handle # noqa: F401 # pylint: disable=W0611
from tests.api.boilerplate import api_handle, error_handler # noqa: F401 # pylint: disable=W0611

models = [*Model]
# NOTE: uncomment that if you're not Opus
Expand All @@ -22,8 +29,8 @@
model_configs = [(model, p) for model in models for p in (config_path / model.value).iterdir()]


# In case of error, the config path will be in the dump, as an argument
@pytest.mark.parametrize("model_config", model_configs)
@error_handler
async def test_generate(api_handle, model_config: Tuple[Model, Path]): # noqa: F811 # pylint: disable=W0621
api = api_handle.api
logger = api.logger
Expand All @@ -33,7 +40,7 @@ async def test_generate(api_handle, model_config: Tuple[Model, Path]): # noqa:

missing_keys = {"prompt", "preset", "global_settings"} - set(config.keys())
if missing_keys:
raise ValueError(f"Config missing keys {', '.join(missing_keys)}")
raise ValueError(f"Config {path} missing keys {', '.join(missing_keys)}")

prompt = config["prompt"]
preset_data = config["preset"]
Expand All @@ -47,7 +54,7 @@ async def test_generate(api_handle, model_config: Tuple[Model, Path]): # noqa:
biases = None # TODO
module = config.get("module", None)

logger.info("Using model %s, preset %s\n", model.value, preset.name)
logger.info("Using model %s, preset %s (%s)\n", model.value, preset.name, path)

gen = await api.high_level.generate(prompt, model, preset, global_settings, bans, biases, module)
# logger.info(gen)
Expand Down

0 comments on commit 0522e9d

Please sign in to comment.