diff --git a/README.md b/README.md index f93e1d7..fe70713 100644 --- a/README.md +++ b/README.md @@ -19,14 +19,14 @@ Download via [pip](https://pypi.org/project/novelai-api): pip install novelai-api ``` -A full list of examples is available in the [example](/example) directory +A full list of examples is available in the [example](example) directory The API works through the NovelAIAPI object. It is split in 2 groups: NovelAIAPI.low_level and NovelAIAPI.high_level ## low_level The low level interface is a strict implementation of the official API (). -It only checks for input types via assert and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True +It only checks for input types via assert, and output schema if NovelAIAPI.low_level.is_schema_validation_enabled is True ## high_level The high level interface builds on the low level one for easier handling of complex settings. diff --git a/novelai_api/Preset.py b/novelai_api/Preset.py index 8f2e339..d481633 100644 --- a/novelai_api/Preset.py +++ b/novelai_api/Preset.py @@ -34,7 +34,14 @@ class Order(IntEnum): } -def enum_contains(enum_class: EnumMeta, value) -> bool: +def enum_contains(enum_class: EnumMeta, value: str) -> bool: + """ + Check if the value provided is valid for the enum + + :param enum_class: Class of the Enum + :param value: Value to check + """ + if not hasattr(enum_class, "enum_member_values"): enum_class.enum_member_values = list(e.value for e in enum_class) @@ -45,6 +52,33 @@ def enum_contains(enum_class: EnumMeta, value) -> bool: return value in values +def _strip_model_version(value: str) -> str: + parts = value.split("-") + + if parts[-1].startswith("v") and parts[-1][1:].isdecimal(): + parts = parts[:-1] + + return "-".join(parts) + + +def collapse_model(enum_class: EnumMeta, value: str): + """ + Collapse multiple version of a model to the last model value + + :param enum_class: Class of the Enum + :param value: Value of the model to collapse + """ + + if not hasattr(enum_class, "enum_member_values"): + enum_class.enum_member_values = {_strip_model_version(e.value): e for e in enum_class} + + values = enum_class.enum_member_values + if len(values) == 0: + raise ValueError(f"Empty enum class: '{enum_class}'") + + return values.get(_strip_model_version(value)) + + class StrEnum(str, Enum): pass @@ -345,7 +379,7 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P # FIXME: collapse model version model_name = data["model"] if "model" in data else "" - model = Model(model_name) if enum_contains(Model, model_name) else None + model = collapse_model(Model, model_name) settings = data["parameters"] if "parameters" in data else {}