Skip to content

Commit

Permalink
[API] Fix sampling options always being enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
Aedial committed Feb 7, 2024
1 parent 71b1336 commit 7fd3664
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
35 changes: 30 additions & 5 deletions novelai_api/Preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,33 @@ class Preset(metaclass=_PresetMetaclass):
name: str
#: Model the preset is for
model: Model
#: Enable state of sampling options
sampling_options: List[bool]

def __init__(self, name: str, model: Model, settings: Optional[Dict[str, Any]] = None):
object.__setattr__(self, "name", name)
object.__setattr__(self, "model", model)

object.__setattr__(self, "_settings", {})
self.update(settings)
self.set_sampling_options_state([True] * len(self._settings["order"]))

def set_sampling_options_state(self, sampling_options_state: List[bool]):
"""
Set the state (enabled/disabled) of the sampling options. Set it after setting the order setting.
It should come in the same order as the order setting.
"""

if "order" not in self._settings:
raise ValueError("The order setting must be set before setting the sampling options state")

if len(sampling_options_state) != len(self._settings["order"]):
raise ValueError(
"The length of the sampling options state list must be equal to the length "
"of the sampling options list"
)

object.__setattr__(self, "sampling_options", sampling_options_state)

def __setitem__(self, key: str, value: Any):
if key not in self._TYPE_MAPPING:
Expand Down Expand Up @@ -364,7 +384,8 @@ def __delattr__(self, name):

def __repr__(self) -> str:
model = self.model.value if self.model is not None else "<?>"
enabled_keys = ", ".join(f"{ORDER_TO_NAME[o]} = {o in self._settings['order']}" for o in Order)
enabled_order = [o for o, enabled in zip(self._settings["order"], self.sampling_options) if enabled]
enabled_keys = ", ".join(f"{ORDER_TO_NAME[o]} = {o in enabled_order}" for o in Order)

return f"Preset: '{self.name} ({model}, {enabled_keys})'"

Expand All @@ -380,7 +401,11 @@ def to_settings(self) -> Dict[str, Any]:

# remove disabled sampling options
if "order" in settings:
order = [Order(e) if isinstance(e, int) else e for e in settings["order"]]
order = [
(Order(o) if isinstance(o, int) else o)
for o, enabled in zip(settings["order"], self.sampling_options)
if enabled
]

for o in Order:
if o not in order:
Expand Down Expand Up @@ -416,8 +441,8 @@ def to_settings(self) -> Dict[str, Any]:
return settings

def __str__(self):
settings = {k: self._settings.get(k, v) for k, v in self.DEFAULTS.items()}
is_default = {k: " (default)" if v == self.DEFAULTS[k] else "" for k, v in settings.items()}
settings = self.to_settings() # use the sanitized settings
is_default = {k: " (default)" if v == self.DEFAULTS.get(k, None) else "" for k, v in settings.items()}

values = "\n".join(f" {k} = {v}{is_default[k]}" for k, v in settings.items())

Expand Down Expand Up @@ -475,7 +500,6 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P

name = data["name"] if "name" in data else "<?>"

# FIXME: collapse model version
model_name = data["model"] if "model" in data else ""
model = collapse_model(Model, model_name)

Expand All @@ -490,6 +514,7 @@ def from_preset_data(cls, data: Optional[Dict[str, Any]] = None, **kwargs) -> "P
settings.pop("logit_bias_groups", None) # get rid of unsupported option

c = cls(name, model, settings)
c.set_sampling_options_state([o["enabled"] for o in order])

return c

Expand Down
2 changes: 1 addition & 1 deletion novelai_api/_low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from novelai_api.Tokenizer import Tokenizer
from novelai_api.utils import tokens_to_b64

PRINT_WITH_PARAMETERS = os.environ.get("NAI_PRINT", None)
PRINT_WITH_PARAMETERS = os.environ.get("NAI_PRINT", False)


# === INTERNALS === #
Expand Down

0 comments on commit 7fd3664

Please sign in to comment.