Skip to content

Commit

Permalink
[TEST] Fix tests and rep pen behavior
Browse files Browse the repository at this point in the history
Apparenly, top_k of 1 is not determistic...
On another note, why can't rep pen adustement be backend ?
  • Loading branch information
Aedial committed Apr 30, 2023
1 parent 0dc489f commit e76020c
Show file tree
Hide file tree
Showing 14 changed files with 1,561 additions and 1,724 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ API
---

.. toctree::
:maxdepth: 3
:maxdepth: 2

tests/api/api
48 changes: 45 additions & 3 deletions novelai_api/Preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ class Order(IntEnum):
}


def enum_contains(enum_class: EnumMeta, value) -> bool:
def enum_contains(enum_class: EnumMeta, value: str) -> bool:
"""
Check if the value provided is valid for the enum
:param enum_class: Class of the Enum
:param value: Value to check
"""

if not hasattr(enum_class, "enum_member_values"):
enum_class.enum_member_values = list(e.value for e in enum_class)

Expand All @@ -45,6 +52,33 @@ def enum_contains(enum_class: EnumMeta, value) -> bool:
return value in values


def _strip_model_version(value: str) -> str:
parts = value.split("-")

if parts[-1].startswith("v") and parts[-1][1:].isdecimal():
parts = parts[:-1]

return "-".join(parts)


def collapse_model(enum_class: EnumMeta, value: str):
"""
Collapse multiple version of a model to the last model value
:param enum_class: Class of the Enum
:param value: Value of the model to collapse
"""

if not hasattr(enum_class, "enum_member_values"):
enum_class.enum_member_values = {_strip_model_version(e.value): e for e in enum_class}

values = enum_class.enum_member_values
if len(values) == 0:
raise ValueError(f"Empty enum class: '{enum_class}'")

return values.get(_strip_model_version(value))


class StrEnum(str, Enum):
pass

Expand Down Expand Up @@ -281,11 +315,19 @@ def to_settings(self) -> Dict[str, Any]:
if "textGenerationSettingsVersion" in settings:
del settings["textGenerationSettingsVersion"] # not API relevant

# remove disabled sampling options
for i, o in enumerate(Order):
if not self._enabled[i]:
settings["order"].remove(o)
settings.pop(ORDER_TO_NAME[o], None)

settings["order"] = [e.value for e in settings["order"]]

# seems that 0 doesn't disable it, but does weird things
if settings.get("repetition_penalty_range", None) == 0:
del settings["repetition_penalty_range"]

# Delete the options that return an unknown error (success status code, but server error)
# delete the options that return an unknown error (success status code, but server error)
if settings.get("repetition_penalty_slope", None) == 0:
del settings["repetition_penalty_slope"]

Expand Down Expand Up @@ -345,7 +387,7 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P

# FIXME: collapse model version
model_name = data["model"] if "model" in data else ""
model = Model(model_name) if enum_contains(Model, model_name) else None
model = collapse_model(Model, model_name)

settings = data["parameters"] if "parameters" in data else {}

Expand Down
5 changes: 5 additions & 0 deletions novelai_api/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ async def _generate(
params.update(global_params)
params.update(kwargs)

# adjust repetition penalty value for Sigurd and Euterpe
if model in (Model.Sigurd, Model.Euterpe) and "repetition_penalty" in params:
rep_pen = params["repetition_penalty"]
params["repetition_penalty"] = (0.525 * (rep_pen - 1) / 7) + 1

params["prefix"] = "vanilla" if prefix is None else prefix

for k, v, c in (("bad_words_ids", bad_words, BanList), ("logit_bias_exp", biases, BiasGroup)):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ path.extend(p.parent for p in find_default_config_files())
xfail_strict = true
empty_parameter_set_mark = "fail_at_collect"
asyncio_mode = "auto"
addopts = "--tb=short"
addopts = "--tb=short -vv"

[build-system]
requires = ["poetry-core"]
Expand Down
10 changes: 5 additions & 5 deletions tests/api/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

def error_handler(func_ext: Optional[Callable[[Any, Any], Awaitable[Any]]] = None, *, attempts: int = 5, wait: int = 5):
"""
Add error handling to the function ``func_ext`` or ``func``
The function must accept a NovelAIAPI object as first arguments
Decorator to add error handling to the decorated function
The function must accept an API object as first arguments
:param func_ext: Substitute for func if the decorator is run without argument
:param func_ext: Substitute for func if the decorator is run without argument. Do not provide it directly
:param attempts: Number of attempts to do before raising the error
:param wait: Time (in seconds) to wait after each call
"""
Expand Down Expand Up @@ -146,7 +146,7 @@ def dumps(e: Any) -> str:
@pytest.fixture(scope="session")
async def api_handle():
"""
API handle for an Async Test
API handle for an Async Test. Use it as a pytest fixture
"""

async with API() as api:
Expand All @@ -156,7 +156,7 @@ async def api_handle():
@pytest.fixture(scope="session")
async def api_handle_sync():
"""
API handle for a Sync Test
API handle for a Sync Test. Use it as a pytest fixture
"""

async with API(sync=True) as api:
Expand Down
Loading

0 comments on commit e76020c

Please sign in to comment.