Skip to content

Commit

Permalink
[API] Collapse model value for Preset
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed Apr 27, 2023
1 parent 3436fef commit 2b48641
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ Download via [pip](https://pypi.org/project/novelai-api):
pip install novelai-api
```

A full list of examples is available in the [example](/example) directory
A full list of examples is available in the [example](example) directory

The API works through the NovelAIAPI object.
It is split in 2 groups: NovelAIAPI.low_level and NovelAIAPI.high_level

## low_level
The low level interface is a strict implementation of the official API (<https://api.novelai.net/docs>).
It only checks for input types via assert and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True
It only checks for input types via assert, and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True

## high_level
The high level interface builds on the low level one for easier handling of complex settings.
Expand Down
38 changes: 36 additions & 2 deletions novelai_api/Preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ class Order(IntEnum):
}


def enum_contains(enum_class: EnumMeta, value) -> bool:
def enum_contains(enum_class: EnumMeta, value: str) -> bool:
"""
Check if the value provided is valid for the enum
:param enum_class: Class of the Enum
:param value: Value to check
"""

if not hasattr(enum_class, "enum_member_values"):
enum_class.enum_member_values = list(e.value for e in enum_class)

Expand All @@ -45,6 +52,33 @@ def enum_contains(enum_class: EnumMeta, value) -> bool:
return value in values


def _strip_model_version(value: str) -> str:
parts = value.split("-")

if parts[-1].startswith("v") and parts[-1][1:].isdecimal():
parts = parts[:-1]

return "-".join(parts)


def collapse_model(enum_class: EnumMeta, value: str):
"""
Collapse multiple version of a model to the last model value
:param enum_class: Class of the Enum
:param value: Value of the model to collapse
"""

if not hasattr(enum_class, "enum_member_values"):
enum_class.enum_member_values = {_strip_model_version(e.value): e for e in enum_class}

values = enum_class.enum_member_values
if len(values) == 0:
raise ValueError(f"Empty enum class: '{enum_class}'")

return values.get(_strip_model_version(value))


class StrEnum(str, Enum):
pass

Expand Down Expand Up @@ -345,7 +379,7 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P

# FIXME: collapse model version
model_name = data["model"] if "model" in data else ""
model = Model(model_name) if enum_contains(Model, model_name) else None
model = collapse_model(Model, model_name)

settings = data["parameters"] if "parameters" in data else {}

Expand Down

0 comments on commit 2b48641

Please sign in to comment.