Skip to content

Commit

Permalink
[API] Fix tests with models without preset, and Tokenizer's docstring
Browse files Browse the repository at this point in the history
Because sphinx seems to consider <|endoftext|> to be
an anchor or something
  • Loading branch information
Aedial committed May 21, 2023
1 parent 7b3dc23 commit 682d8d6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 2 deletions.
2 changes: 2 additions & 0 deletions novelai_api/Preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ def _import_officials():
path = pathlib.Path(__file__).parent / "presets" / f"presets_{model.value.replace('-', '_')}"
if not path.exists():
warnings.warn(f"Missing preset folder for model {model.value}")
cls._officials_values[model.value] = []
cls._officials[model.value] = {}
continue

if (path / "default.txt").exists():
Expand Down
2 changes: 1 addition & 1 deletion novelai_api/Tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, model_path: str):
def encode(self, s: str) -> List[int]:
"""
Encode the provided text using the SentencePiece tokenizer.
This workaround is needed because sentencepiece cannot handle <|endoftext|>
This workaround is needed because sentencepiece cannot handle `<|endoftext|>`
:param s: Text to encode
Expand Down
17 changes: 17 additions & 0 deletions novelai_api/_low_level.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import enum
import io
import json
import operator
import zipfile
from typing import Any, AsyncIterator, Dict, List, NoReturn, Optional, Tuple, Union
Expand All @@ -20,6 +22,21 @@
SSE_FIELDS = ["event", "data", "id", "retry"]


def print_with_parameters(args: Dict[str, Any]):
    """
    Pretty-print a request payload as indented JSON (debugging helper).

    The payload passed in is never modified; a copy is printed, in which two keys
    get special treatment:

    - ``"input"``: when longer than 30 items, only the first 10 and last 10 are shown,
      separated by ``...``
    - ``"parameters"``: every value is converted with ``str()``

    Any other value that JSON cannot represent is printed through ``str()`` instead
    of making the helper raise ``TypeError``.

    :param args: Payload to print
    """

    # A shallow copy is enough: only top-level keys are rebound below, nothing nested is mutated
    a = copy.copy(args)

    if "input" in a:
        full_input = a["input"]
        if 30 < len(full_input):
            a["input"] = f"{full_input[:10]}...{full_input[-10:]}"

    if "parameters" in a:
        a["parameters"] = {k: str(v) for k, v in a["parameters"].items()}

    # default=str is deliberate: this output is for human eyes only, so a readable
    # fallback is better than crashing on a non-serializable value
    print(json.dumps(a, indent=4, default=str))


# === API === #
class LowLevel:
_parent: "NovelAIAPI" # noqa: F821
Expand Down
6 changes: 5 additions & 1 deletion tests/api/test_textgen_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
models = list(set(models) - {Model.Genji, Model.Snek, Model.HypeBot, Model.Inline})

config_path = Path(__file__).parent / "sanity_text_sets"
model_configs = [(model, p) for model in models for p in (config_path / model.value).iterdir()]

# Collect one (model, config file) pair for every file found in each model directory
model_configs = []
for model_dir in config_path.iterdir():
    # Only directories hold test sets; a stray file here would make the inner iterdir() raise
    if not model_dir.is_dir():
        continue

    # Use .name rather than .stem: .stem drops everything after the last dot,
    # which would truncate any model identifier that contains one
    m = Model(model_dir.name)
    model_configs.extend((m, p) for p in model_dir.iterdir())


@pytest.mark.parametrize("model_config", model_configs)
Expand Down

0 comments on commit 682d8d6

Please sign in to comment.