From 0522e9d393179a127e19a6fc07800f0da7828f9c Mon Sep 17 00:00:00 2001 From: Aedial Date: Thu, 27 Apr 2023 19:59:31 +0200 Subject: [PATCH] [TEST] Fix tests Added error handling Cleaned things Added tests and descriptions /\!\ Sanity still fails /\!\ --- tests/api/__init__.py | 0 tests/api/boilerplate.py | 118 +++++++++++++++------------- tests/api/conftest.py | 3 + tests/api/test_imagegen_samplers.py | 54 +++++++++++++ tests/api/test_sync_gen.py | 12 ++- tests/api/test_textgen_presets.py | 44 ++++++++--- tests/api/test_textgen_sanity.py | 15 +++- 7 files changed, 175 insertions(+), 71 deletions(-) create mode 100644 tests/api/__init__.py create mode 100644 tests/api/test_imagegen_samplers.py diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/boilerplate.py b/tests/api/boilerplate.py index c50c097..37cc9cb 100644 --- a/tests/api/boilerplate.py +++ b/tests/api/boilerplate.py @@ -1,8 +1,9 @@ import asyncio +import functools import json from logging import Logger, StreamHandler from os import environ as env -from typing import Any, NoReturn +from typing import Any, Awaitable, Callable, NoReturn, Optional import pytest from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession @@ -57,58 +58,69 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): if not self._sync: await self._session.__aexit__(exc_type, exc_val, exc_tb) - async def run_test(self, func, *args, attempts: int = 5, wait: int = 5): - """ - Run the function ``func`` with the provided arguments and retry on error handling - The function must accept a NovelAIAPI object as first arguments - - :param func: Function to run - :param args: Arguments to provide to the function - :param attempts: Number of attempts to do before raising the error - :param wait: Time (in seconds) to wait after each call - """ - - err: Exception = RuntimeError("Error placeholder. Shouldn't happen") - for _ in range(attempts): - try: - res = await func(self.api, *args) - await asyncio.sleep(wait) - - return res - except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e: - err = e - retry = True - - except NovelAIError as e: - err = e - retry = any( - [ - e.status == 502, # Bad Gateway - e.status == 520, # Cloudflare Unknown Error - e.status == 524, # Cloudflare Gateway Error - ] - ) - - if not retry: - break - - # 10s wait between each retry - await asyncio.sleep(10) - - # no internet: ping every 5 mins until connection is re-established - async with ClientSession() as session: - while True: - try: - rsp = await session.get("https://www.google.com", timeout=5 * 60) - rsp.raise_for_status() - - break - except ClientConnectionError: - await asyncio.sleep(5 * 60) - except asyncio.TimeoutError: - pass - - raise err + +def error_handler(func_ext: Optional[Callable[[Any, Any], Awaitable[Any]]] = None, *, attempts: int = 5, wait: int = 5): + """ + Add error handling to the function ``func_ext`` or ``func`` + The function must accept a NovelAIAPI object as first arguments + + :param func_ext: Substitute for func if the decorator is run without argument + :param attempts: Number of attempts to do before raising the error + :param wait: Time (in seconds) to wait after each call + """ + + def decorator(func: Callable[[Any, Any], Awaitable[Any]]): + @functools.wraps(func) + async def wrap(*args, **kwargs): + err: Exception = RuntimeError("Error placeholder. Shouldn't happen") + for _ in range(attempts): + try: + res = await func(*args, **kwargs) + await asyncio.sleep(wait) + + return res + except (ClientConnectionError, asyncio.TimeoutError, ClientPayloadError) as e: + err = e + retry = True + + except NovelAIError as e: + err = e + retry = any( + [ + e.status == 502, # Bad Gateway + e.status == 520, # Cloudflare Unknown Error + e.status == 524, # Cloudflare Gateway Error + ] + ) + + if not retry: + break + + # 10s wait between each retry + await asyncio.sleep(10) + + # no internet: ping every 5 mins until connection is re-established + async with ClientSession() as session: + while True: + try: + rsp = await session.get("https://www.google.com", timeout=5 * 60) + rsp.raise_for_status() + + break + except ClientConnectionError: + await asyncio.sleep(5 * 60) + except asyncio.TimeoutError: + pass + + raise err + + return wrap + + # allow to run the function without argument + if func_ext is None: + return decorator + + return decorator(func_ext) class JSONEncoder(json.JSONEncoder): diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 92e6b11..e560eeb 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -21,6 +21,9 @@ def pytest_terminal_summary(terminalreporter): terminalreporter.line("") +# TODO: add html reporting + + # cannot put in boilerplate because pytest is a mess @pytest.fixture(scope="session") def event_loop(): diff --git a/tests/api/test_imagegen_samplers.py b/tests/api/test_imagegen_samplers.py new file mode 100644 index 0000000..a5300d4 --- /dev/null +++ b/tests/api/test_imagegen_samplers.py @@ -0,0 +1,54 @@ +""" +{filename} +============================================================================== + +Test which samplers currently work +""" + +import itertools +from typing import Tuple + +import pytest + +from novelai_api import NovelAIError +from novelai_api.ImagePreset import ImageModel, ImagePreset, ImageSampler, UCPreset +from tests.api.boilerplate import api_handle, error_handler # noqa: F401 # pylint: disable=W0611 + +sampler_xfail = pytest.mark.xfail(True, raises=NovelAIError, reason="The sampler doesn't currently work") + +models = list(ImageModel) +models.remove(ImageModel.Anime_Inpainting) + +samplers = list(ImageSampler) +model_samplers = list(itertools.product(models, samplers)) + + +@pytest.mark.parametrize( + "model_sampler", + [ + pytest.param(e, marks=sampler_xfail) if e[1] in (ImageSampler.nai_smea, ImageSampler.plms) else e + for e in model_samplers + ], +) +@error_handler +async def test_samplers( + api_handle, model_sampler: Tuple[ImageModel, ImagePreset] # noqa: F811 # pylint: disable=W0621 +): + """ + Test the presets to ensure they work with the API + """ + + api = api_handle.api + model, sampler = model_sampler + + logger = api_handle.logger + logger.info(f"Testing model {model} with sampler {sampler}") + + preset = ImagePreset(sampler=sampler) + + # Furry doesn't have UCPreset.Preset_Low_Quality_Bad_Anatomy + if model is ImageModel.Furry: + preset.uc_preset = UCPreset.Preset_Low_Quality + + async for _, _ in api.high_level.generate_image("1girl", model, preset): + pass diff --git a/tests/api/test_sync_gen.py b/tests/api/test_sync_gen.py index cf09a5a..65eb662 100644 --- a/tests/api/test_sync_gen.py +++ b/tests/api/test_sync_gen.py @@ -1,22 +1,27 @@ """ -Test if sync capabilities work without problem -This test only checks if sync works, not if the result is right. It's the job of the other tests +{filename} +============================================================================== + +| Test if sync capabilities work without problem +| This test only checks if sync works, not if the result is right, it's the job of the other tests """ from novelai_api.GlobalSettings import GlobalSettings from novelai_api.Preset import Model, Preset from novelai_api.Tokenizer import Tokenizer from novelai_api.utils import b64_to_tokens, decrypt_user_data -from tests.api.boilerplate import api_handle_sync # noqa: F401 # pylint: disable=W0611 +from tests.api.boilerplate import api_handle_sync, error_handler # noqa: F401 # pylint: disable=W0611 prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam at dolor dictum, interdum est sed, consequat arcu. Pellentesque in massa eget lorem fermentum placerat in pellentesque purus. Suspendisse potenti. Integer interdum, felis quis porttitor volutpat, est mi rutrum massa, venenatis viverra neque lectus semper metus. Pellentesque in neque arcu. Ut at arcu blandit purus aliquet finibus. Suspendisse laoreet risus a gravida semper. Aenean scelerisque et sem vitae feugiat. Quisque et interdum diam, eu vehicula felis. Ut tempus quam eros, et sollicitudin ligula auctor at. Integer at tempus dui, quis pharetra purus. Duis venenatis tincidunt tellus nec efficitur. Nam at malesuada ligula." # noqa: E501 # pylint: disable=C0301 model = Model.Krake +@error_handler async def test_is_reachable(api_handle_sync): # noqa: F811 # pylint: disable=W0621 assert await api_handle_sync.api.low_level.is_reachable() is True +@error_handler async def test_download(api_handle_sync): # noqa: F811 # pylint: disable=W0621 key = api_handle_sync.encryption_key keystore = await api_handle_sync.api.high_level.get_keystore(key) @@ -24,6 +29,7 @@ async def test_download(api_handle_sync): # noqa: F811 # pylint: disable=W0621 decrypt_user_data(modules, keystore) +@error_handler async def test_generate(api_handle_sync): # noqa: F811 # pylint: disable=W0621 api = api_handle_sync.api diff --git a/tests/api/test_textgen_presets.py b/tests/api/test_textgen_presets.py index 10dce37..ce26100 100644 --- a/tests/api/test_textgen_presets.py +++ b/tests/api/test_textgen_presets.py @@ -1,25 +1,39 @@ +""" +{filename} +============================================================================== + +Tests pertaining to the Preset class +""" + from typing import Tuple import pytest -from novelai_api import NovelAIAPI from novelai_api.GlobalSettings import GlobalSettings from novelai_api.Preset import Model, Preset from novelai_api.Tokenizer import Tokenizer from novelai_api.utils import b64_to_tokens -from tests.api.boilerplate import api_handle # noqa: F401 # pylint: disable=W0611 +from tests.api.boilerplate import api_handle, error_handler # noqa: F401 # pylint: disable=W0611 prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam at dolor dictum, interdum est sed, consequat arcu. Pellentesque in massa eget lorem fermentum placerat in pellentesque purus. Suspendisse potenti. Integer interdum, felis quis porttitor volutpat, est mi rutrum massa, venenatis viverra neque lectus semper metus. Pellentesque in neque arcu. Ut at arcu blandit purus aliquet finibus. Suspendisse laoreet risus a gravida semper. Aenean scelerisque et sem vitae feugiat. Quisque et interdum diam, eu vehicula felis. Ut tempus quam eros, et sollicitudin ligula auctor at. Integer at tempus dui, quis pharetra purus. Duis venenatis tincidunt tellus nec efficitur. Nam at malesuada ligula." # noqa: E501 # pylint: disable=C0301 -models = [*Model] +models = list(Model) # NOTE: uncomment that if you're not Opus # models.remove(Model.Genji) # models.remove(Model.Snek) models_presets = [(model, preset) for model in models for preset in Preset[model]] -models_presets_default = [(model, Preset.from_default(model)) for model in models] -async def simple_generate(api: NovelAIAPI, model: Model, preset: Preset): +@pytest.mark.parametrize("model_preset", models_presets) +@error_handler +async def test_presets(api_handle, model_preset: Tuple[Model, Preset]): # noqa: F811 # pylint: disable=W0621 + """ + Test the presets to ensure they work with the API + """ + + api = api_handle.api + model, preset = model_preset + logger = api.logger logger.info("Using model %s, preset %s\n", model.value, preset.name) @@ -29,11 +43,19 @@ async def simple_generate(api: NovelAIAPI, model: Model, preset: Preset): logger.info(Tokenizer.decode(model, b64_to_tokens(gen["output"]))) -@pytest.mark.parametrize("model_preset", models_presets) -async def test_presets(api_handle, model_preset: Tuple[Model, Preset]): # noqa: F811 # pylint: disable=W0621 - await api_handle.api.run_test(simple_generate, *model_preset) +@pytest.mark.parametrize("model", models) +async def preset_from_default(model: Model): + """ + Test the from_default constructor of Preset + """ + + Preset.from_default(model) + +@pytest.mark.parametrize("model", models) +async def preset_from_official(model: Model): + """ + Test the from_official constructor of Preset + """ -@pytest.mark.parametrize("model_preset", models_presets_default) -async def test_presets_default(api_handle, model_preset: Tuple[Model, Preset]): # noqa: F811 # pylint: disable=W0621 - await api_handle.api.run_test(simple_generate, *model_preset) + Preset.from_official(model) diff --git a/tests/api/test_textgen_sanity.py b/tests/api/test_textgen_sanity.py index 99337bf..789e154 100644 --- a/tests/api/test_textgen_sanity.py +++ b/tests/api/test_textgen_sanity.py @@ -1,3 +1,10 @@ +""" +{filename} +============================================================================== + +Test if the generated content is consistent with the frontend +""" + import json from pathlib import Path from typing import Any, Dict, Tuple @@ -8,7 +15,7 @@ from novelai_api.Preset import Model, Preset from novelai_api.Tokenizer import Tokenizer from novelai_api.utils import b64_to_tokens -from tests.api.boilerplate import api_handle # noqa: F401 # pylint: disable=W0611 +from tests.api.boilerplate import api_handle, error_handler # noqa: F401 # pylint: disable=W0611 models = [*Model] # NOTE: uncomment that if you're not Opus @@ -22,8 +29,8 @@ model_configs = [(model, p) for model in models for p in (config_path / model.value).iterdir()] -# In case of error, the config path will be in the dump, as an argument @pytest.mark.parametrize("model_config", model_configs) +@error_handler async def test_generate(api_handle, model_config: Tuple[Model, Path]): # noqa: F811 # pylint: disable=W0621 api = api_handle.api logger = api.logger @@ -33,7 +40,7 @@ async def test_generate(api_handle, model_config: Tuple[Model, Path]): # noqa: missing_keys = {"prompt", "preset", "global_settings"} - set(config.keys()) if missing_keys: - raise ValueError(f"Config missing keys {', '.join(missing_keys)}") + raise ValueError(f"Config {path} missing keys {', '.join(missing_keys)}") prompt = config["prompt"] preset_data = config["preset"] @@ -47,7 +54,7 @@ async def test_generate(api_handle, model_config: Tuple[Model, Path]): # noqa: biases = None # TODO module = config.get("module", None) - logger.info("Using model %s, preset %s\n", model.value, preset.name) + logger.info("Using model %s, preset %s (%s)\n", model.value, preset.name, path) gen = await api.high_level.generate(prompt, model, preset, global_settings, bans, biases, module) # logger.info(gen)