Skip to content

Commit

Permalink
CLN/ENH: Rename and refactor datapipes, add datasets; fix 574 724 754 (
Browse files Browse the repository at this point in the history
…#755)

* Rename vak/datasets -> vak/datapipes

* Rename frame_classifcation.window_dataset.WindowDataset -> TrainDatapipe

* Rename frame_classification/window_dataset.py -> train_datapipe.py

* Fix WindowDataset -> TrainDatapipe in docstrings

* Rename frame_classification.frames_dataset.FramesDataset -> infer_datapipe.InferDatapipe

* Rename transforms.StandardizeSpect -> FramesStandarizer

* Import FramesStandarizer in datapipes/frame_classification/infer_datapipe.py

* Add module-level docstring in vak/datapipes/__init__.py

* Rewrite transforms.defaults.frames_classification.EvalItemTransform and PredictItemTransform as a single class, InferItemTransform, and remname spect_standardizer -> frames_standardizer in that module

* Fix bug in view_as_window_batch so it works on 1-D arrays, add type hinting in src/vak/transforms/functional.py

* Change frame_labels_transform in InferItemTransform to be a torchvision.transforms.Compose, so we get back a windowed batch

* Remove TODO in src/vak/models/frame_classification_model.py

* Rewrite TrainDatapipe to always use TrainItemTransform, add parameters that get passed to TrainItemTransform when instatiating it inside TrainDatapipe.__init__

* Rewrite frames_classification.InferDatapipe to always use transforms.default.frame_classification.InferItemTransform, add parameters that get passed to InferItemTransform when instatiating it inside InferDatapipe.__init__

* Rewrite train.frame_classification to pass kwargs into datapipes that now use default transforms, and no longer call transforms.defaults.get

* Rewrite predict.frame_classification to pass kwargs into datapipes that now use default transforms, and no longer call transforms.defaults.get

* Rewrite eval.frame_classification to pass kwargs into datapipes that now use default transforms, and no longer call transforms.defaults.get

* Rewrite predict.frame_classification to pass kwargs into datapipes that now use default transforms, and no longer call transforms.defaults.get

* Rename 'spect_scaler_path' -> 'frames_standardizer_path'

* Rename 'normalize_spectrogram' -> 'standardize_frames'

* Fix 'SpectScaler' -> 'FramesStandardizer', 'normalize spectrogram' -> 'standardize (normalize) frames'

* Fix 'SpectScaler' -> 'FramesStandardizer' in tests/

* Fix key names in doc/toml

* Add missing comma in src/vak/train/frame_classification.py

* Rename config/valid-version-1.1.toml -> valid-version-1.2.toml

* Fix normalize spectrograms -> standardize frames more places in docs

* Fix datapipes.frame_classification.InferDatapipe to have needed parameters for item transform

* Fix datapipes.frame_classification.TrainDatapipe to have needed parameters for item transform

* Fix arg name 'spect_standardizer -> frames_standardizer in src/vak/train/frame_classification.py

* fixup fix TrainDatapipe parameters

* Fix variable name in src/vak/datapipes/frame_classification/train_datapipe.py

* Add missing arg return_padding_mask in src/vak/train/frame_classification.py

* Fix transforms.default.frame_classification.InferItemTransform to not window frame labels, just convert them to LongTensor

* Revise docstring in eval/frame_classification

* Remove item_transform from docstring in datapipes/frame_classification/train_datapipe.py

* Add return_padding_mask arg in vak/predict/frame_classification.py

* Remove src/vak/transforms/defaults/parametric_umap.py

* Rename/rewrite Datapipe class for ParametricUMAP, hard-code in transform

* Remove transforms/defaults/get.py, remove related imports in transforms/defaults/__init__.py

* Finish removing transform fetching for ParametricUMAP

* Fix typo in src/vak/eval/frame_classification.py

* Fix "StandardizeSpect" -> "FramesStandardizer" in src/vak/learncurve/frame_classification.py

* Apply changes from nox lint session

* Make flake8 fixes, remove unused function get_default_frame_classification_transform

* Fix "StandardizeSpect" -> "FramesStandardizer" in tests/scripts/vaktestdata/configs.py"

* WIP: Add datasets/ with biosoundsegbench

* Renam tests/test_datasets -> test_datapipes, fix tests

* Fix 'StandardizeSpect' -> 'FramesStandardizer' in two tests

* Remove two uses of vak.transforms.defaults.get_default_transform from tests

* Fix datapipe used in tests/test_models/test_parametric_umap_model.py

* Use TYPE_CHECKING to avoid circular import in src/vak/datapipes/frame_classification/infer_datapipe.py

* Add method 'fit_inputs_targets_csv_path' to FramesStandardizer, rewrite 'fit_dataset_path' method to just call this new method

* fixup add method

* Add unit test for FramesStandardizer.fit_inputs_targets_csv_path

* Remove unused import from src/vak/transforms/transforms.py

* Remove unused import in src/vak/transforms/defaults/frame_classification.py

* Pep8 fix in src/vak/datasets/__init__.py

* Apply linting to src/vak/transforms/transforms.py

* Correct docstring in src/vak/transforms/defaults/frame_classification.py

* Import datasets in src/vak/__init__.py

* Rename datapipes/frame_classification/constants.FRAME_LABELS_EXT -> MULTI_FRAME_LABELS_EXT, and change value to 'multi-frame-labels.npy', and change value of FRAME_LABELS_NPY_PATH_COL_NAME to 'multi_frame_labels_npy_path'

* Rename vak.datapipes.frame_classification.constants.FRAME_LABELS_NPY_PATH_COL_NAME -> MULTI_FRAME_LABELS_PATH_COL_NAME

* Rename key in item returned by frame_classification.TrainItemTransform and InferItemTransform; 'frame_labels' -> 'multi_frame_labels'

* WIP: Get BioSoundSegBench class working

* Rewrite FrameClassificationModel to handle different target types

* Add VALID_SPLITS to common.constants

* In datasets/biosoundsegbench.py: change VALID_TARGET_TYPES to be the ones we're using for experiments right now, fix TrainItemTransform to handle target types, clean up __init__ method validation

* Add initial unit tests for BioSoundSegBench dataset

* Add helper function vak.datasets.get

* Clean up how we validate target_type in datasets.BioSoundSegBench.__init__

* Add tests/test_datasets/__init__.py (to make a sub-package)

* Add initial unit tests for vak.datasets.get

* Modify BioSoundSegBench.__init__ so we can write splits_path as just the filename

* Use expanded_user_path converter on path and splits_path attributes of DatasetConfig

* Rename BOUNDARY_ONEHOT_PATH_COL_NAME -> BOUNDARY_FRAME_LABELS_PATH_COL_NAME in datasets/biosoundsegbench.py

* Modify datasets.BioSoundSegBench to compute metadata from splits_json path

* Fix mock_biosoundsegbench_dataset fixture so mocked files follow naming conventions of dataset

* Modify mock_biosoundsegbench_dataset fixture to save labelmaps.json

* Change BioSoundSegBench.__init__ so we have training_replicate_metadata attribute, frame_dur attribute, and labelmap attribute

* Add DATASETS dict in dataset/__init__.py, used by vak.datasets.get to look up class (value) by name (key)

* Use vak.datasets.DATASETS in vak.datasets.get to get class

* Rewrite BioSoundSegBench.__init__ so we can either pass in a FramesStandardizer instance or tell it to fit a new one to the specified split, that then gets added to the transform

* Import DATASETS inside vak.datasets.get to avoid circular import

* Make fixes in datasets/biosoundsegbench.py: import FramesStandardizer inside TrainItemTransform.__init__, fix tmp_splits_path -> splits-jsons (plural), add needed __len__ method to class

* Rename BioSoundSegBench property 'input_shape' -> 'shape' for consistency with frame_classification datapipes

* Get vak/train/frame_classification.py to the point where it runs

* Add missing self in BioSoundSegBench._getitemval

* Rewrite src/vak/eval/frame_classification.py to work with built-in datasets, and remove 'split' parameter from eval_frame_classification_model function -- check if 'split' is in dataset_config and if not, default to 'test'

* Remove split argument in call to eval_frame_classification_model inside src/vak/learncurve/frame_classification.py

* Remove split parameter from eval._eval.eval -- it's not an attribute of EvalConfig and we can now pass in a 'split' through dataset_config

* Remove 'split' parameter from eval_parametric_umap_model, check if 'split' in dataset_config and if not default to 'test'

* Rewrite src/vak/predict/frame_classification.py to work with built-in datasets; check if 'split' is in dataset_config and if not, default to 'predict'

* Add comments to structure src/vak/train/frame_classification.py

* Fix how we check for key in src/vak/predict/frame_classification.py

* Fix how we check for key in dict in src/vak/eval/parametric_umap.py

* Fix how we check for key in dict in src/vak/eval/frame_classification.py

* Fix unit tests in test_dataset.py: assert that path attributes are vak.converters.expanded_user_path(value from config), not pathlib.Path

* Fix how we parametrize tests/test_dataset/test_get.py

* In BioSoundSegBench.__init__, fix how we calculate frame_dur and how we set labelmap attribute for binary/boundary frame labels

* In FrameClassificationModel.validation_step, convert Levenshtein distance to float to squelch warning from Lightning

* Fix FrameClassificationModel so train/val with multi-class + boundary labels works

* Fix vak.cli.predict to not assume that config has a prep attribute

* Fix how we override default split with a split from dataset_config['params'] in predict/frame_classification and eval/frame_classification

* Change BioSoundSegBench so __getitem__ can return 'frames_path' in 'item' for eval/predict

* In predict.frame_classification, set 'return_frames_path' to True in dataset_config['params'] since we need this for predictions

* Add constant DEFAULT_SPECT_FORMAT in common.constants

* Fix SPECT_KEY -> TIMEBINS_KEY in cli.prep

* Fix how we determine input_type and spect_format for built-in datasets in predict/frame_classification

* Add nn/loss/crossentropy.py, wraps torch.nn.CrossEntropy, but converts weight arg as list to tensor

* Fixup add loss

* Use nn.loss.CrossEntropy with TweetyNet model

* Clean up prediction_step in FrameClassificationModel

* Get predict working for multi_frame_labels and boundary_frame_labels, still need to test binary_frame_labels and (boundary, multi)

* Rename 'unlabeled_label' -> 'background_label' in transforms/frame_labels

* Rename 'unlabeled_label' -> 'background_label' in tests/test_transforms/test_frame_labels

* Rewrite transforms/frame_labels/functional.py to handle boundary labels

- Add `boundary_labels_to_segment_inds_list' that finds segment indexing arrays from a list of boundary labels
- Rename `to_segment_inds` -> `frame_labels_to_segment_inds_list
- Have `preprocess` optionally take `boundary_labels` and use it to find segments, instead of frame labels
- Fix type annotations to use npt.NDArray instead of np.ndarray

* Change how FrameClassificationModel calls loss for multi-class + boundary targets -- assume we pass to an instance of a loss function, and get back either a scalar loss or a dict mapping loss names to scalar values

* Change arg name 'unlabeled_label' -> 'background_label' in prep/frame_classification/make_splits.py

* Fix predict.frame_classification for multi-class, and add logic for multi-class frame labels with boundary frame labels

* Add DEFAULT_BACKGROUND_LABEL to common.constants

* Use DEFAULT_BACKGROUND_LABEL in transforms.frame_labels.functional

* Rename unlabeled -> background_label in common.labels

* Add background_label in docstring in common/labels.py

* Add 'background_label' to FrameClassificationModel, defaults to common.constants.DEFAULT_BACKGROUND_LABEL, used to validate length of string labels in labelmap

* Fix 'unlabeled' -> common.constants.DEFAULT_BACKGROUND_LABEL in anohter place in common/labels.py

* Fix unlabeled -> background label in docstrings in transforms

* Use 'background_label' argument in place of magic string 'unlabeled' in prep/frame_classification/learncurve.py

* Fix unlabeled -> background label in docstrings in transforms/frame_labels/functional.py

* Add background_label to docstring in src/vak/prep/frame_classification/learncurve.py

* Add background_label to function in src/vak/prep/frame_classification/make_splits.py

* Add background_label parameter to src/vak/predict/frame_classification.py and add type annotations to function signature

* Fix unlabeled -> background / vak.common.constants.DEFAULT_BACKGROUND_LABEL in tests

* Fix 'map_unlabeled' -> 'map_background' in tests/

* Fix 'constants' -> 'common' in src/vak/models/frame_classification_model.py

* Fix arg name map_unlabeled -> map_background

* Fix arg name map_unlabeled -> map_background in prep/parametric_umap

* Fix 'unlabeled' -> vak.common.constants.DEFAULT_BACKGROUND_LABEL in tests/

* Fix name `to_inds_list` -> segment_inds_list_from_class_labels` in test_transforms/test_frame_labels/test_functional.py
  • Loading branch information
NickleDave authored May 11, 2024
1 parent 2c6e469 commit 5003113
Show file tree
Hide file tree
Showing 118 changed files with 3,060 additions and 1,363 deletions.
14 changes: 14 additions & 0 deletions doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,20 @@ The :mod:`vak.datasets` module contains datasets built into vak.
datasets.frame_classification
datasets.parametric_umap

Datapipes
---------

The :mod:`vak.datapipes` module contains datapipes for loading dataset
generated by :func:`vak.prep.prep`.

.. autosummary::
:toctree: generated
:template: module.rst
:recursive:

datapipes.frame_classification
datapipes.parametric_umap

Metrics
-------
The :mod:`vak.metrics` module contains metrics used
Expand Down
4 changes: 2 additions & 2 deletions doc/toml/gy6or6_eval.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/che
# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations;
# this is used when generating predictions
labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json"
# spect_scaler_path: path to file containing SpectScaler that was fit to training set
# frames_standardizer_path: path to file containing SpectScaler that was fit to training set
# We want to transform the data we predict on in the exact same way
spect_scaler_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect"
frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect"
# batch_size
# for predictions with a frame classification model, this should always be 1
# and will be ignored if it's not
Expand Down
4 changes: 2 additions & 2 deletions doc/toml/gy6or6_predict.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/che
# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations;
# this is used when generating predictions
labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json"
# spect_scaler_path: path to file containing SpectScaler that was fit to training set
# frames_standardizer_path: path to file containing SpectScaler that was fit to training set
# We want to transform the data we predict on in the exact same way
spect_scaler_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect"
frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect"
# batch_size
# for predictions with a frame classification model, this should always be 1
# and will be ignored if it's not
Expand Down
4 changes: 2 additions & 2 deletions doc/toml/gy6or6_train.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ root_results_dir = "/PATH/TO/FOLDER/results/train"
batch_size = 8
# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split
num_epochs = 2
# normalize_spectrograms: if true, normalize spectrograms per frequency bin, so mean of each is 0.0 and std is 1.0
# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0
# across the entire training split
normalize_spectrograms = true
standardize_frames = true
# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0
# (a step is one batch fed through the network)
# saves a checkpoint if the monitored evaluation metric improves (which is model specific)
Expand Down
2 changes: 2 additions & 0 deletions src/vak/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
cli,
common,
config,
datapipes,
datasets,
eval,
learncurve,
Expand Down Expand Up @@ -42,6 +43,7 @@
"cli",
"common",
"config",
"datapipes",
"datasets",
"eval",
"learncurve",
Expand Down
2 changes: 1 addition & 1 deletion src/vak/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def eval(toml_path: str | pathlib.Path) -> None:
output_dir=cfg.eval.output_dir,
num_workers=cfg.eval.num_workers,
batch_size=cfg.eval.batch_size,
spect_scaler_path=cfg.eval.spect_scaler_path,
frames_standardizer_path=cfg.eval.frames_standardizer_path,
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
)
2 changes: 1 addition & 1 deletion src/vak/cli/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def learning_curve(toml_path):
num_workers=cfg.learncurve.num_workers,
results_path=results_path,
post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs,
normalize_spectrograms=cfg.learncurve.normalize_spectrograms,
standardize_frames=cfg.learncurve.standardize_frames,
shuffle=cfg.learncurve.shuffle,
val_step=cfg.learncurve.val_step,
ckpt_step=cfg.learncurve.ckpt_step,
Expand Down
8 changes: 4 additions & 4 deletions src/vak/cli/predict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from pathlib import Path

from .. import config
from .. import common, config
from .. import predict as predict_module
from ..common.logging import config_logging_for_cli, log_version

Expand Down Expand Up @@ -33,7 +33,7 @@ def predict(toml_path):
force=True,
)
log_version(logger)
logger.info("Logging results to {}".format(cfg.prep.output_dir))
logger.info("Logging results to {}".format(cfg.predict.output_dir))

if cfg.predict.dataset.path is None:
raise ValueError(
Expand All @@ -49,8 +49,8 @@ def predict(toml_path):
checkpoint_path=cfg.predict.checkpoint_path,
labelmap_path=cfg.predict.labelmap_path,
num_workers=cfg.predict.num_workers,
timebins_key=cfg.prep.spect_params.timebins_key,
spect_scaler_path=cfg.predict.spect_scaler_path,
timebins_key=cfg.prep.spect_params.timebins_key if cfg.prep else common.constants.TIMEBINS_KEY,
frames_standardizer_path=cfg.predict.frames_standardizer_path,
annot_csv_filename=cfg.predict.annot_csv_filename,
output_dir=cfg.predict.output_dir,
min_segment_dur=cfg.predict.min_segment_dur,
Expand Down
4 changes: 2 additions & 2 deletions src/vak/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def train(toml_path):
num_epochs=cfg.train.num_epochs,
num_workers=cfg.train.num_workers,
checkpoint_path=cfg.train.checkpoint_path,
spect_scaler_path=cfg.train.spect_scaler_path,
frames_standardizer_path=cfg.train.frames_standardizer_path,
results_path=results_path,
normalize_spectrograms=cfg.train.normalize_spectrograms,
standardize_frames=cfg.train.standardize_frames,
shuffle=cfg.train.shuffle,
val_step=cfg.train.val_step,
ckpt_step=cfg.train.ckpt_step,
Expand Down
2 changes: 1 addition & 1 deletion src/vak/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
If a helper/utility function is only used in one module,
it should live either in that module or another at the same level.
See for example :mod:`vak.prep.prep_helper` or
:mod:`vak.datsets.window_dataset._helper`.
:mod:`vak.datsets.train_datapipe._helper`.
"""

from . import (
Expand Down
7 changes: 6 additions & 1 deletion src/vak/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""constants used by multiple modules.
"""Constants used by multiple modules.
Defined here to avoid circular imports.
"""

Expand Down Expand Up @@ -26,6 +26,7 @@
"npz": np.load,
}
VALID_SPECT_FORMATS = list(SPECT_FORMAT_LOAD_FUNCTION_MAP.keys())
DEFAULT_SPECT_FORMAT = "npz"

# ---- valid types of training data, the $x$ that goes into a network
VALID_X_SOURCES = {"audio", "spect"}
Expand Down Expand Up @@ -57,3 +58,7 @@
"npz": SPECT_NPZ_EXTENSION,
"mat": ".mat",
}

VALID_SPLITS = ("predict", "test", "train", "val")

DEFAULT_BACKGROUND_LABEL = "background"
43 changes: 28 additions & 15 deletions src/vak/common/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import numpy as np
import pandas as pd

from . import annotation
from . import annotation, constants


def to_map(labelset: set, map_unlabeled: bool = True) -> dict:
def to_map(
labelset: set, map_background: bool = True, background_label: str = constants.DEFAULT_BACKGROUND_LABEL
) -> dict:
"""Convert set of labels to `dict`
mapping those labels to a series of consecutive integers
from 0 to n inclusive,
Expand All @@ -18,21 +20,31 @@ def to_map(labelset: set, map_unlabeled: bool = True) -> dict:
from annotations of a vocalization into
a label for every time bin in a spectrogram of that vocalization.
If ``map_unlabeled`` is True, then the label 'unlabeled'
will be added to labelset, and will map to 0,
If ``map_background`` is True, then a label
will be added to labelset representing a background class
(any segment that is not labeled).
The default for this label is
:const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
This string label will map to class index 0,
so the total number of classes is n + 1.
Parameters
----------
labelset : set
Set of labels used to annotate a dataset.
map_unlabeled : bool
If True, include key 'unlabeled' in mapping.
map_background : bool
If True, include key specified by
``background_label`` in mapping.
Any time bins in a spectrogram
that do not have a label associated with them,
e.g. a silent gap between vocalizations,
will be assigned the integer
that the 'unlabeled' key maps to.
that the background key maps to.
background_label: str, optional
The string label applied to segments belonging to the
background class.
Default is
:const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
Returns
-------
Expand All @@ -45,11 +57,12 @@ def to_map(labelset: set, map_unlabeled: bool = True) -> dict:
)

labellist = []
if map_unlabeled is True:
labellist.append("unlabeled")

if map_background is True:
# NOTE we append background label *first*
labellist.append(background_label)
# **then** extend with the rest of the labels
labellist.extend(sorted(list(labelset)))

# so that background_label maps to class index 0 by default in next line
labelmap = dict(zip(labellist, range(len(labellist))))
return labelmap

Expand Down Expand Up @@ -124,7 +137,7 @@ def from_df(

# added to fix https://github.com/NickleDave/vak/issues/373
def multi_char_labels_to_single_char(
labelmap: dict, skip: tuple[str] = ("unlabeled",)
labelmap: dict, skip: tuple[str] = (constants.DEFAULT_BACKGROUND_LABEL,)
) -> dict:
"""Return a copy of a ``labelmap`` where any
labels that are strings with multiple characters
Expand All @@ -146,9 +159,9 @@ def multi_char_labels_to_single_char(
to integers. As returned by
``vak.labels.to_map``.
skip : tuple
Of strings, labels to leave
as multiple characters.
Default is ('unlabeled',).
A tuple of labels to leave as multiple characters.
Default is a tuple containing just
:const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
Returns
-------
Expand Down
1 change: 0 additions & 1 deletion src/vak/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from .train import TrainConfig
from .trainer import TrainerConfig


__all__ = [
"config",
"dataset",
Expand Down
6 changes: 4 additions & 2 deletions src/vak/config/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import attr.validators
from attr import asdict, define, field

from ..common.converters import expanded_user_path


@define
class DatasetConfig:
Expand All @@ -31,9 +33,9 @@ class DatasetConfig:
Default is None.
"""

path: pathlib.Path = field(converter=pathlib.Path)
path: pathlib.Path = field(converter=expanded_user_path)
splits_path: pathlib.Path | None = field(
converter=attr.converters.optional(pathlib.Path), default=None
converter=attr.converters.optional(expanded_user_path), default=None
)
name: str | None = field(
converter=attr.converters.optional(str), default=None
Expand Down
6 changes: 3 additions & 3 deletions src/vak/config/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ class EvalConfig:
Argument to torch.DataLoader. Default is 2.
labelmap_path : str
path to 'labelmap.json' file.
spect_scaler_path : str
path to a saved SpectScaler object used to normalize spectrograms.
frames_standardizer_path : str
path to a saved :class:`vak.transforms.FramesStandardizer` object used to standardize (normalize) frames.
If spectrograms were normalized and this is not provided, will give
incorrect results.
post_tfm_kwargs : dict
Expand Down Expand Up @@ -152,7 +152,7 @@ class EvalConfig:
converter=converters.optional(expanded_user_path), default=None
)
# optional, transform
spect_scaler_path = field(
frames_standardizer_path = field(
converter=converters.optional(expanded_user_path),
default=None,
)
Expand Down
17 changes: 12 additions & 5 deletions src/vak/config/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from .train import TrainConfig
from .trainer import TrainerConfig

REQUIRED_KEYS = ("dataset", "model", "root_results_dir", "trainer",)
REQUIRED_KEYS = (
"dataset",
"model",
"root_results_dir",
"trainer",
)


@define
Expand Down Expand Up @@ -45,9 +50,9 @@ class LearncurveConfig(TrainConfig):
Argument to torch.DataLoader.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
normalize_spectrograms : bool
if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
Normalization is done by subtracting off the mean for each frequency bin
standardize_frames : bool
if True, use :class:`vak.transforms.FramesStandardizer` to standardize the frames.
Normalization is done by subtracting off the mean for each row
of the training set and then dividing by the std for that frequency bin.
This same normalization is then applied to validation + test data.
val_step : int
Expand Down Expand Up @@ -75,6 +80,7 @@ class LearncurveConfig(TrainConfig):
See the docstring of the transform for more details on
these arguments and how they work.
"""

post_tfm_kwargs = field(
validator=validators.optional(are_valid_post_tfm_kwargs),
converter=converters.optional(convert_post_tfm_kwargs),
Expand All @@ -91,7 +97,8 @@ def from_config_dict(cls, config_dict: dict) -> LearncurveConfig:
by loading a valid configuration toml file with
:func:`vak.config.parse.from_toml_path`,
and then using key ``learncurve``,
i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``."""
i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``.
"""
for required_key in REQUIRED_KEYS:
if required_key not in config_dict:
raise KeyError(
Expand Down
7 changes: 3 additions & 4 deletions src/vak/config/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .model import ModelConfig
from .trainer import TrainerConfig


REQUIRED_KEYS = (
"checkpoint_path",
"dataset",
Expand Down Expand Up @@ -50,8 +49,8 @@ class PredictConfig:
num_workers : int
Number of processes to use for parallel loading of data.
Argument to torch.DataLoader. Default is 2.
spect_scaler_path : str
path to a saved SpectScaler object used to normalize spectrograms.
frames_standardizer_path : str
path to a saved :class:`vak.transforms.FramesStandardizer` object used to standardize (normalize) frames.
If spectrograms were normalized and this is not provided, will give
incorrect results.
annot_csv_filename : str
Expand Down Expand Up @@ -104,7 +103,7 @@ class PredictConfig:
)

# optional, transform
spect_scaler_path = field(
frames_standardizer_path = field(
converter=converters.optional(expanded_user_path),
default=None,
)
Expand Down
Loading

0 comments on commit 5003113

Please sign in to comment.