From 682d8d6e5f07e58e44bef2ef3a6336d783dc3d46 Mon Sep 17 00:00:00 2001 From: Aedial Date: Sun, 21 May 2023 20:01:08 +0200 Subject: [PATCH] [API] Fix tests with models without preset, and Tokenizer's docstring Sphinx mis-parses the bare <|endoftext|> in the docstring as reST inline markup (apparently a |endoftext| substitution reference), so wrap it in backticks --- novelai_api/Preset.py | 2 ++ novelai_api/Tokenizer.py | 2 +- novelai_api/_low_level.py | 17 +++++++++++++++++ tests/api/test_textgen_sanity.py | 6 +++++- 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/novelai_api/Preset.py b/novelai_api/Preset.py index 24adc10..ae1eacc 100644 --- a/novelai_api/Preset.py +++ b/novelai_api/Preset.py @@ -521,6 +521,8 @@ def _import_officials(): path = pathlib.Path(__file__).parent / "presets" / f"presets_{model.value.replace('-', '_')}" if not path.exists(): warnings.warn(f"Missing preset folder for model {model.value}") + cls._officials_values[model.value] = [] + cls._officials[model.value] = {} continue if (path / "default.txt").exists(): diff --git a/novelai_api/Tokenizer.py b/novelai_api/Tokenizer.py index 7a76a1c..fd3967e 100644 --- a/novelai_api/Tokenizer.py +++ b/novelai_api/Tokenizer.py @@ -26,7 +26,7 @@ def __init__(self, model_path: str): def encode(self, s: str) -> List[int]: """ Encode the provided text using the SentencePiece tokenizer. 
- This workaround is needed because sentencepiece cannot handle <|endoftext|> + This workaround is needed because sentencepiece cannot handle `<|endoftext|>` :param s: Text to encode diff --git a/novelai_api/_low_level.py b/novelai_api/_low_level.py index 4d530dc..4fc8c5d 100644 --- a/novelai_api/_low_level.py +++ b/novelai_api/_low_level.py @@ -1,5 +1,7 @@ +import copy import enum import io +import json import operator import zipfile from typing import Any, AsyncIterator, Dict, List, NoReturn, Optional, Tuple, Union @@ -20,6 +22,21 @@ SSE_FIELDS = ["event", "data", "id", "retry"] +def print_with_parameters(args: Dict[str, Any]): + """ + Print the provided parameters in a nice way + """ + + a = copy.deepcopy(args) + if "input" in a: + a["input"] = f"{a['input'][:10]}...{a['input'][-10:]}" if 30 < len(a["input"]) else a["input"] + + if "parameters" in a: + a["parameters"] = {k: str(v) for k, v in a["parameters"].items()} + + print(json.dumps(a, indent=4)) + + # === API === # class LowLevel: _parent: "NovelAIAPI" # noqa: F821 diff --git a/tests/api/test_textgen_sanity.py b/tests/api/test_textgen_sanity.py index 30485a3..c94228c 100644 --- a/tests/api/test_textgen_sanity.py +++ b/tests/api/test_textgen_sanity.py @@ -23,7 +23,11 @@ models = list(set(models) - {Model.Genji, Model.Snek, Model.HypeBot, Model.Inline}) config_path = Path(__file__).parent / "sanity_text_sets" -model_configs = [(model, p) for model in models for p in (config_path / model.value).iterdir()] + +model_configs = [] +for model_dir in config_path.iterdir(): + m = Model(model_dir.stem) + model_configs.extend([(m, p) for p in model_dir.iterdir()]) @pytest.mark.parametrize("model_config", model_configs)