Skip to content

Commit

Permalink
[cm] Categorical parameter implementation (#73)
Browse files Browse the repository at this point in the history
* [cm] Adding categorical parameter

* [cm] Updating categorical param to use NamedTuple

* [cm] Removing unnecessary check
  • Loading branch information
christhetree authored Oct 5, 2024
1 parent da7f2be commit 0e2be24
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 67 deletions.
8 changes: 4 additions & 4 deletions examples/example_clipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from torch import Tensor

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.utils import save_neutone_model

logging.basicConfig()
Expand Down Expand Up @@ -59,9 +59,9 @@ def is_experimental(self) -> bool:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter("min", "min clip threshold", default_value=0.15),
KnobNeutoneParameter("max", "max clip threshold", default_value=0.15),
KnobNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
ContinuousNeutoneParameter("min", "min clip threshold", default_value=0.15),
ContinuousNeutoneParameter("max", "max clip threshold", default_value=0.15),
ContinuousNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
]

@tr.jit.export
Expand Down
8 changes: 4 additions & 4 deletions examples/example_clipper_prefilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn as nn
from torch import Tensor

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.filters import FIRFilter, FilterType
from neutone_sdk.utils import save_neutone_model

Expand Down Expand Up @@ -65,9 +65,9 @@ def is_experimental(self) -> bool:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter("min", "min clip threshold", default_value=0.15),
KnobNeutoneParameter("max", "max clip threshold", default_value=0.15),
KnobNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
ContinuousNeutoneParameter("min", "min clip threshold", default_value=0.15),
ContinuousNeutoneParameter("max", "max clip threshold", default_value=0.15),
ContinuousNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
]

@tr.jit.export
Expand Down
8 changes: 4 additions & 4 deletions examples/example_overdrive-random.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn as nn
from torch import Tensor

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.tcn_1d import FiLM
from neutone_sdk.utils import save_neutone_model

Expand Down Expand Up @@ -202,9 +202,9 @@ def get_citation(self) -> str:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter("depth", "Effect Depth", 0.0),
KnobNeutoneParameter("P1", "Feature modulation 1", 0.0),
KnobNeutoneParameter("P2", "Feature modulation 2", 0.0),
ContinuousNeutoneParameter("depth", "Effect Depth", 0.0),
ContinuousNeutoneParameter("P1", "Feature modulation 1", 0.0),
ContinuousNeutoneParameter("P2", "Feature modulation 2", 0.0),
]

@torch.jit.export
Expand Down
10 changes: 5 additions & 5 deletions examples/example_rave.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torchaudio
from torch import Tensor

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.audio import (
AudioSample,
AudioSamplePair,
Expand Down Expand Up @@ -60,20 +60,20 @@ def is_experimental(self) -> bool:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Chaos", description="Magnitude of latent noise", default_value=0.0
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z edit index",
description="Index of latent dimension to edit",
default_value=0.0,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z scale",
description="Scale of latent variable",
default_value=0.5,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z offset",
description="Offset of latent variable",
default_value=0.5,
Expand Down
10 changes: 5 additions & 5 deletions examples/example_rave_prefilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torchaudio
from torch import Tensor, nn

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.audio import (
AudioSample,
AudioSamplePair,
Expand Down Expand Up @@ -69,20 +69,20 @@ def is_experimental(self) -> bool:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Chaos", description="Magnitude of latent noise", default_value=0.0
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z edit index",
description="Index of latent dimension to edit",
default_value=0.0,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z scale",
description="Scale of latent variable",
default_value=0.5,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z offset",
description="Offset of latent variable",
default_value=0.5,
Expand Down
10 changes: 5 additions & 5 deletions examples/example_rave_v1_prefilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torchaudio
from torch import Tensor, nn

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.audio import (
AudioSample,
AudioSamplePair,
Expand Down Expand Up @@ -67,22 +67,22 @@ def is_experimental(self) -> bool:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Chaos",
description="Magnitude of latent noise",
default_value=0.0,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z edit index",
description="Index of latent dimension to edit",
default_value=0.0,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z scale",
description="Scale of latent variable",
default_value=0.5,
),
KnobNeutoneParameter(
ContinuousNeutoneParameter(
name="Z offset",
description="Offset of latent variable",
default_value=0.5,
Expand Down
8 changes: 4 additions & 4 deletions examples/example_spectral_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from torch import Tensor

from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, KnobNeutoneParameter
from neutone_sdk import WaveformToWaveformBase, NeutoneParameter, ContinuousNeutoneParameter
from neutone_sdk.realtime_stft import RealtimeSTFT
from neutone_sdk.utils import save_neutone_model

Expand Down Expand Up @@ -171,11 +171,11 @@ def is_experimental(self) -> bool:

def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
KnobNeutoneParameter(
ContinuousNeutoneParameter(
"center", "center frequency of the filter", default_value=0.3
),
KnobNeutoneParameter("width", "width of the filter", default_value=0.5),
KnobNeutoneParameter(
ContinuousNeutoneParameter("width", "width of the filter", default_value=0.5),
ContinuousNeutoneParameter(
"amount", "spectral attenuation amount", default_value=0.9
),
]
Expand Down
2 changes: 2 additions & 0 deletions neutone_sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
SDK_VERSION = "1.4.3"

MAX_N_PARAMS = 4
MAX_N_CATEGORICAL_VALUES = 20
MAX_N_CATEGORICAL_LABEL_CHARS = 20
MAX_N_AUDIO_SAMPLES = 3

DEFAULT_DAW_SR = 48000
Expand Down
2 changes: 1 addition & 1 deletion neutone_sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:

# Save parameter metadata
self.neutone_parameters_metadata = {
f"p{idx + 1}": p.to_metadata_dict()
f"p{idx + 1}": p.to_metadata()
for idx, p in enumerate(self.get_neutone_parameters())
}

Expand Down
19 changes: 4 additions & 15 deletions neutone_sdk/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@
"description": {"type": "string"},
"default_value": {"type": ["integer", "number", "string"]},
"used": {"type": "boolean"},
"type": {"type": "string", "enum": ["knob"]},
"max_n_chars": {"type": "integer", "minimum": -1},
"type": {"type": "string", "enum": ["continuous"]},
"max_n_chars": {"type": ["null", "integer"], "minimum": -1},
"n_values": {"type": ["null", "integer"], "minimum": 2},
"labels": {"type": ["null", "array"], "items": {"type": "string"}},
},
}
},
Expand Down Expand Up @@ -208,17 +210,4 @@ def validate_metadata(metadata: dict) -> bool:
AudioSample.from_b64(audio_sample_pair["in"])
AudioSample.from_b64(audio_sample_pair["out"])

# We shouldn't have any problems here but as a sanity check
for param_metadata in metadata["neutone_parameters"].values():
try:
if param_metadata["type"] == "knob":
assert (
0.0 <= param_metadata["default_value"] <= 1.0
), "Default values for continuous NeutoneParameters should be between 0 and 1"
except:
log.error(
f"Could not convert default_value to float for parameter {param_metadata.name} "
)
return False

return True
87 changes: 76 additions & 11 deletions neutone_sdk/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
import os
from abc import ABC
from enum import Enum
from typing import Union, NamedTuple, Dict
from typing import Union, NamedTuple, Optional, List

from neutone_sdk import constants

logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))


class NeutoneParameterType(Enum):
KNOB = "knob"
CONTINUOUS = "continuous"
CATEGORICAL = "categorical"
TEXT = "text"


Expand All @@ -20,15 +23,17 @@ class ParameterMetadata(NamedTuple):
default_value: Union[int, float, str]
used: bool
type: str
max_n_chars: int = -1
max_n_chars: Optional[int] = None
n_values: Optional[int] = None
labels: Optional[List[str]] = None


class NeutoneParameter(ABC):
"""
Defines a Neutone Parameter abstract base class.
The name and the description of the parameter will be shown as a tooltip
within the UI. This parameter has no functionality.
within the UI. This parameter has no functionality and is meant to subclassed.
"""

def __init__(
Expand All @@ -45,7 +50,7 @@ def __init__(
self.used = used
self.type = param_type

def to_metadata_dict(self) -> ParameterMetadata:
def to_metadata(self) -> ParameterMetadata:
return ParameterMetadata(
name=self.name,
description=self.description,
Expand All @@ -55,13 +60,14 @@ def to_metadata_dict(self) -> ParameterMetadata:
)


class KnobNeutoneParameter(NeutoneParameter):
class ContinuousNeutoneParameter(NeutoneParameter):
"""
Defines a knob Neutone Parameter that the user can use to control a model.
Defines a continuous Neutone Parameter that the user can use to control a model.
The name and the description of the parameter will be shown as a tooltip
within the UI. `default_value` must be between 0 and 1 and will be used
as a default in the plugin when no presets are available.
within the UI.
`default_value` must be between 0 and 1 and will be used as a default in the plugin
when no presets are available.
"""

def __init__(
Expand All @@ -72,7 +78,66 @@ def __init__(
description,
default_value,
used,
NeutoneParameterType.KNOB,
NeutoneParameterType.CONTINUOUS,
)
assert (
0.0 <= default_value <= 1.0
), "`default_value` for continuous params must be between 0 and 1"


class CategoricalNeutoneParameter(NeutoneParameter):
"""
Defines a categorical Neutone Parameter that the user can use to control a model.
The name and the description of the parameter will be shown as a tooltip
within the UI.
`n_values` must be an int greater than or equal to 2 and less than or equal to
`constants.MAX_N_CATEGORICAL_VALUES`.
`default_value` must be in the range [0, `n_values` - 1].
`labels` is a list of strings that will be used as the labels for the parameter.
"""

def __init__(
self,
name: str,
description: str,
n_values: int,
default_value: int,
labels: Optional[List[str]] = None,
used: bool = True,
):
super().__init__(
name, description, default_value, used, NeutoneParameterType.CATEGORICAL
)
assert 2 <= n_values <= constants.MAX_N_CATEGORICAL_VALUES, (
f"`n_values` for categorical params must between 2 and "
f"{constants.MAX_N_CATEGORICAL_VALUES}"
)
assert (
0 <= default_value <= n_values - 1
), "`default_value` for categorical params must be between 0 and `n_values`-1"
self.n_values = n_values
if labels is None:
labels = [str(idx) for idx in range(n_values)]
else:
assert len(labels) == self.n_values, "labels must have `n_values` elements"
assert all(
len(label) < constants.MAX_N_CATEGORICAL_LABEL_CHARS for label in labels
), (
f"All labels must have length less than "
f"{constants.MAX_N_CATEGORICAL_LABEL_CHARS} characters"
)
self.labels = labels

def to_metadata(self) -> ParameterMetadata:
return ParameterMetadata(
name=self.name,
description=self.description,
default_value=self.default_value,
used=self.used,
type=self.type.value,
n_values=self.n_values,
labels=self.labels,
)


Expand Down Expand Up @@ -105,7 +170,7 @@ def __init__(
), "`default_value` must be a string of length less than `max_n_chars`"
self.max_n_chars = max_n_chars

def to_metadata_dict(self) -> ParameterMetadata:
def to_metadata(self) -> ParameterMetadata:
return ParameterMetadata(
name=self.name,
description=self.description,
Expand Down
Loading

0 comments on commit 0e2be24

Please sign in to comment.