From df876e3920515feb9fa36d8f280f34b2ea392ba3 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 23:22:55 -0400 Subject: [PATCH] ENH: Switch to version 1.0 of config file format, fix #685 #345 #748 (#750) * WIP: Add src/vak/config/dataset.py * Add module-level docstring + type annotations in src/vak/config/parse.py * WIP: Fix how cli.prep adds dataset path to toml config file * Change table names in src/vak/config/valid.toml * Rename section -> table in config/parse.py * In cli/prep change 'section' -> 'table' and lowercase table names * In config/config.py, change 'section' -> 'table' and lowercase table names * Change '[PREP]' -> '[vak.prep]' in config/prep.py * WIP: Change table names in config files in tests/data_for_tests/configs * Make tomlkit a dependency in pyproject.toml, drop toml * Change config/parse.py to use tomlkit * Update example configs in doc/toml/ * Add link to example config files in docs, in error messages in config/validators.py * Remove 'spect_params' from REQUIRED_OPTIONS in config/parse.py, this is not a top-level table and will be an attribute of prep instead * Rename 'config_toml' -> 'config_dict' in config/parse.py * Fix function _validate_tables_arg_convert_list in config/parse.py * Fix error message formatting in src/vak/config/validators.py * Add ModelConfig class to config/model.py, add type annotations, fix config_from_toml_dict to look in specific section * Fixup fixing config_from_toml_dict to look in specific section * Rewrite config/eval.py with 'modern' attrs * Fixup rewrite config/eval with 'modern' attrs * Rewrite config/learncurve.py with 'modern' attrs * Rewrite config/predict.py with 'modern' attrs * Rewrite config/prep.py with 'modern' attrs * Rewrite config/train.py with 'modern' attrs * Rename Dataset -> DatasetConfig in config/dataset.py * Add are_table_options_valid to config/validators.py, will be used by classmethods from_config_dict * WIP: Add from_config_dict classmethod to EvalConfig * WIP: Add tests/test_config/test_dataset.py * Make fixes to ModelConfig class, fix circular imports in config/model.py module * Write tests in tests/test_config/test_dataset.py * Use tomlkit not toml in cli/prep.py * Use tomlkit in tests/fixtures/annot.py * Use tomlkit in tests/scripts/vaktestdata/configs.py * Use tomlkit in tests/scripts/vaktestdata/source_files.py * Use tomlkit in tests/test_config/test_validators.py * Remove spect_params attribute from Config in config/config.py, fix class' docstring * Reorder attributes, fix typo in docstring of DatasetConfig * Rewrite config/parse.py assuming config classes have from_config_dict classmethod * Rename `table` -> `table_name` in a couple validators in config/validators.py * Remove use of config.model.config_from_toml_path in cli/eval.py * Remove use of config.model.config_from_toml_path in cli/learncurve.py * Remove use of config.model.config_from_toml_path in cli/predict.py * Remove use of config.model.config_from_toml_path in cli/train.py * Remove functions from config/model.py: config_from_toml_path and config_from_toml_dict * Add `to_dict` method to ModelConfig * Use to_dict() method of ModelConfig class in cli functions * Fix how we get labelset from config in tests/fixtures/annot.py * WIP: Clean up / rewrite tests/fixtures/config.py * Fix model tables in tests/data_for_tests/configs * Finish unit tests in tests/test_config/test_model.py * Fix model tables in doc/toml * Rename data_for_tests/configs/invalid_option_config.toml -> invalid_key_config.toml * Rename
are_options_valid/are_table_options_valid -> are_keys_valid/are_table_keys_valid in config/validators.py * Rename two fixtures in fixtures/config.py: invalid_section_config_path -> invalid_table_config_path, invalid_option_config_path -> invalid_key_config_path * Fix validator names in config/parse.py, rename TABLE_CLASSES constant -> TABLE_CLASSES_MAP * Rename config/valid.toml -> valid-version-1.0.toml, fix how model table is declared * Fix VALID_TOML_PATH in config/validators.py after renaming config/valid.toml -> config/valid-version-1.0.toml * Import config classes in vak/config/__init__.py * Add _tomlkit_to_popo to tests/fixtures/config.py so we operate on dicts not tomlkit.TOMLDocument * Add _tomlkit_to_popo to config/parse.py so we operate on dicts not tomlkit.TOMLDocument * Finish rewriting tests for tests/test_config/test_prep.py * Rewrite EvalConfig with from_config_dict method * Rewrite LearncurveConfig with from_config_dict method * Rewrite PredictConfig with from_config_dict method * Rewrite PrepConfig with from_config_dict method * Rewrite TrainConfig with from_config_dict method * Remove functions from config/parse.py * Rename config/parse.py -> config/load.py * Make functions in config/parse.py into classmethods on Config class * Use config.Config.from_toml_path everywhere instead of config.parse.from_toml_path * Make fixes in Config classmethods * Change load._load_toml_from_path again so that it returns config_dict['vak'], to avoid writing ['vak'] everywhere in calling functions * Add docstring to are_tables_valid in config/validators.py * Lowercase config table names in tests/scripts/vaktestdata/configs.py * In tests/scripts/vaktestdata/source_files.py, change cfg.spect_params -> cfg.prep.spect_params, fix how we change values in toml, add tables_to_parse arg to call to Config.from_toml_path * in test_cli/test_prep.py, call vak.config.load not vak.config.parse * Fix how we instantiate DatasetConfig and ModelConfig in EvalConfig.from_config_dict method * Fix how we instantiate DatasetConfig and ModelConfig in PredictConfig.from_config_dict method * Fix how we instantiate DatasetConfig and ModelConfig in TrainConfig.from_config_dict method * Fix how we instantiate DatasetConfig and ModelConfig in LearncurveConfig.from_config_dict method * Remove breakpoint in src/vak/config/model.py * Fix wrong variable name so we save configs correctly in tests/scripts/vaktestdata/source_files.py, and add tables_to_parse arg to Config.from_toml_path, so we don't get 'missing dataset' errors * Fix how we re-write configs, in tests/scripts/vaktestdata/configs.py * Add model and dataset tables to get those keys in top-level tables, in src/vak/config/valid-version-1.0.toml * Change cfg.table.dataset_path -> cfg.table.dataset.path in vak/cli modules (e.g., vak.train.dataset.path) * Get tests passing for tests/test_config/test_eval.py * Clean up tests/test_config/test_eval.py * Get tests passing in tests/test_config/test_predict.py * Fix how we access config_toml in tests/scripts/vaktestdata/configs.py -- missing 'vak' key * Add pytest.mark.parametrize to tests/test_config/test_learncurve.py * Rewrite tests in tests/test_config/test_train.py * Rewrite tests in tests/test_config/test_config.py * Add unit test to tests/test_config/test_model.py * Add unit test for exceptions in tests/test_config/test_eval.py * Fix 'cfg.spect_params' -> 'cfg.prep.spect_params' in src/vak/cli/predict.py * Add unit test for exceptions in tests/test_config/test_learncurve.py * Add unit test for exceptions in
tests/test_config/test_train.py * Add more test cases to TestEvalConfig.test_from_config_dict_raises * Add more test cases to TestLearncurveConfig.test_from_config_dict_raises * Add unit test for exceptions in tests/test_config/test_predict.py * Add two unit tests that PrepConfig raises expected exceptions * Fix/add unit tests in tests/test_config/test_config.py * Change order of parameters for Config.from_config_dict, make toml_path last param * Fix/add unit tests in tests/fixtures/config.py * Rename test_config/test_parse.py -> test_load.py, fix/rewrite tests * Fix tests in tests/test_config/test_spect_params.py * Make fixups in tests/test_config * Apply fixes from linter * Make more linting fixes * Speed up install in nox session 'lint', only install linting tools * Change names 'section'/'option' -> 'table'/'key' in tests * Fix tests in tests/test_cli/test_eval.py * Finish fixing cli tests, fix renaming * Fix how we get 'path' from 'dataset' table in configs, in tests/fixtures/csv.py * Fix how we get 'path' from 'dataset' table in configs, in tests/fixtures/dataset.py * Change .dataset_path -> .dataset.path in tests/ * Fix how we get model config and rename config attribute .dataset_path -> .dataset.path throughout tests * In tests/, fixup change .dataset_path -> .dataset.path, use model.name where we used to use just 'model' attribute of config * Fix fixture specific_config_toml_path in fixtures/config.py to handle case where we need to access sub-table and change a key in it--right now this is just ['dataset']['path'] * Fix how we change ['dataset']['path'] value in tests/test_eval/test_frame_classification.py * Fix how we change ['dataset']['path'] value in config in several tests * Use ModelConfig attribute name where needed in tests/test_learncurve/test_frame_classification.py * In tests, replace calls to vak.config.model.config_from_toml_path with calls to ModelConfig method to_dict() * Change cfg.spect_params -> cfg.prep.spect_params in tests * Fix cfg.predict -> cfg.predict.dataset.path in tests/test_predict/test_frame_classification.py * Fix constant LABELSET_NOTMAT in fixtures/annot.py so it is a list of str, not a Tomlkit.String class * Fix cfg.learncurve -> cfg.learncurve.dataset.path in tests/test_prep/test_frame/test_learncurve.py * Cast pathlib to str before adding to tomldoc, in tests/test_train/ * Change transform/dataset params keys in data_for_tests/configs to a dataset table with a params key * Add `params` attribute to DatasetConfig * Change transform/dataset params keys in doc/toml/ to a dataset table with a params key * Rewrite vak/config/model.py method 'to_dict' as 'asdict', using attrs asdict function.
We now return 'name' and will just get it from the dict instead of having a separate 'model_name' parameter for functions that take 'model_config' * Add asdict method to DatasetConfig class, like ModelConfig.asdict * Fix calls to model.to_dict() -> model.asdict() * Add unit tests for DatasetConfig.asdict * Add unit tests for ModelConfig.asdict * Add an assertion in tests/test_config/test_dataset.py * Remove transform params and dataset_params from EvalConfig, will just use dataset attribute, a DatasetConfig, with its params attribute * Remove dataset/transform_params key-value pairs in valid-version-1.0.toml, and add params key to dataset tables with in-line table params * Remove train/val/dataset/transform_params from TrainConfig, will use DatasetConfig attribute params instead * Remove train/val/dataset/transform_params from PredictConfig, will use DatasetConfig attribute params instead * Revise transforms.defaults.frame_classification.TrainItemTransform and change get_default_frame_classification_transform to return an instance of the TrainItemTransform when 'mode' is 'train' * Make vak.table.dataset.params into an in-line table in toml files in tests/data_for_tests/configs * Fix attribute name in frame_classification.TrainItemTransform.__init__: source_transform -> frames_transform * Rewrite datasets.frame_classification.WindowDataset to require item_transform, and assume that it is an instance of transforms.frame_classification.TrainItemTransform * Rewrite datasets.frame_classification.FramesDataset to make item_transform required * Rewrite src/vak/train/frame_classification.py: remove params model_name, train/val_transform_params, train/val_dataset_params, and dataset_path, replace with dataset_config and just have model_config contain name * Rewrite src/vak/train/_train.py: remove params model_name, train/val_transform_params, train/val_dataset_params, and dataset_path, replace with dataset_config and just have model_config contain name * Rewrite vak/cli/train.py to call train._train.train with just model_config and dataset_config, remove model_name, dataset_path, train/val_transform_params and train/val_dataset_params * Fix how we unpack batch in training_step method of FrameClassificationModel * Change transform_kwargs parameter of transforms.defaults.parametric_umap.get_default_parametric_umap_transform to default to None, and if None to be an empty dict * Change transform_kwargs parameter of transforms.defaults.frame_classification.get_default_frame_classification_transform to default to None, and if None to be an empty dict * Change DatasetConfig.params attribute to default to empty dict, so we can unpack with ** operator even when no params are specified * Fix DatasetConfig.from_config_dict method to not use dict.get method, so we don't set attributes to None inadvertently * Modify transforms.defaults.get so that transform_kwargs is None by default. 
Also revise docstring and type annotations * Rewrite src/vak/train/parametric_umap.py to use model_config and dataset_config parameters, removing parameters val/train_transform_params + val/train_dataset_params and dataset_path * Rewrite vak/eval/frame_classification.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite vak/eval/parametric_umap.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite vak/eval/eval_.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite cli.eval to pass model_config and dataset_config into eval_module.eval, remove dataset_path/transform_params/dataset_params arguments * Unpack dataset_config[params] with ** inside train/frame_classification.py, instead of directly getting window_size from the params dict * Rewrite vak/learncurve/frame_classification.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite vak/learncurve/learncurve.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite cli.learncurve to pass model_config and dataset_config into learning_curve.learncurve, remove dataset_path/transform_params/dataset_params arguments * Rewrite vak/predict/frame_classification.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite vak/predict/parametric_umap.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Rewrite vak/predict/predict.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Fix dataset_path -> dataset_config[path] and add missing variable model_name in src/vak/learncurve/learncurve.py * Fix dataset_path -> dataset_config[path] and add missing variable model_name in src/vak/learncurve/frame_classification.py * Rewrite vak/cli/predict.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Remove non-existent dataset_params variable in vak/predict/frame_classification.py * Fix unit tests for DatasetConfig to test 'params' attribute gets handled correctly * Remove train/val_dataset_params and train/val_transform_params from test cases we parametrize with in tests/test_config/ * Use DatasetConfig.params attribute where we need to in tests/test_datasets * Fix method name ModelConfig.to_dict -> asdict in tests/ * In tests for eval/learncurve/predict/train, use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path * Fix use of default transform and dataset.params attribute in test_models/test_base.py * Fix config snippets in docs * Apply linting to src/ * Raise 'from e' with errors in eval/predict/train/frame_classification modules --- doc/get_started/autoannotate.md | 78 ++--- doc/reference/config.md | 22 +- doc/toml/gy6or6_eval.toml | 22 +- doc/toml/gy6or6_predict.toml | 19 +- doc/toml/gy6or6_train.toml | 29 +- noxfile.py | 4 +- pyproject.toml | 2 +- src/scripts/download_autoannotate_data.py | 1 + src/vak/__main__.py | 1 + src/vak/cli/eval.py | 27 +- src/vak/cli/learncurve.py | 16 +- src/vak/cli/predict.py | 16
+- src/vak/cli/prep.py | 98 +++--- src/vak/cli/train.py | 16 +- src/vak/common/__init__.py | 1 + src/vak/common/constants.py | 1 + src/vak/common/converters.py | 10 +- src/vak/common/labels.py | 7 +- src/vak/common/logging.py | 1 + src/vak/common/paths.py | 1 + src/vak/common/tensorboard.py | 1 + src/vak/common/timebins.py | 1 + src/vak/common/validators.py | 1 + src/vak/config/__init__.py | 25 +- src/vak/config/config.py | 178 +++++++++-- src/vak/config/dataset.py | 59 ++++ src/vak/config/eval.py | 113 ++++--- src/vak/config/learncurve.py | 58 +++- src/vak/config/load.py | 94 ++++++ src/vak/config/model.py | 158 +++++----- src/vak/config/parse.py | 202 ------------- src/vak/config/predict.py | 116 ++++---- src/vak/config/prep.py | 83 ++++-- src/vak/config/spect_params.py | 1 + src/vak/config/train.py | 114 +++---- .../{valid.toml => valid-version-1.0.toml} | 63 ++-- src/vak/config/validators.py | 135 ++++++--- .../frame_classification/frames_dataset.py | 20 +- .../datasets/frame_classification/helper.py | 1 + .../datasets/frame_classification/metadata.py | 1 + .../frame_classification/window_dataset.py | 28 +- src/vak/datasets/parametric_umap/metadata.py | 1 + .../parametric_umap/parametric_umap.py | 1 + src/vak/eval/eval_.py | 40 +-- src/vak/eval/frame_classification.py | 51 ++-- src/vak/eval/parametric_umap.py | 37 +-- src/vak/learncurve/frame_classification.py | 53 +--- src/vak/learncurve/learncurve.py | 30 +- src/vak/metrics/util.py | 1 + src/vak/models/base.py | 1 + src/vak/models/convencoder_umap.py | 1 + src/vak/models/decorator.py | 1 + src/vak/models/definition.py | 1 + src/vak/models/ed_tcn.py | 1 + src/vak/models/frame_classification_model.py | 7 +- src/vak/models/get.py | 1 + src/vak/models/parametric_umap_model.py | 1 + src/vak/models/registry.py | 1 + src/vak/models/tweetynet.py | 1 + src/vak/nets/tweetynet.py | 1 + src/vak/nn/loss/umap.py | 3 +- src/vak/nn/modules/activation.py | 1 + src/vak/nn/modules/conv.py | 1 + src/vak/plot/annot.py | 1 + src/vak/plot/learncurve.py | 1 + src/vak/plot/spect.py | 1 + src/vak/predict/frame_classification.py | 150 +++++----- src/vak/predict/parametric_umap.py | 73 +++-- src/vak/predict/predict_.py | 126 ++++---- src/vak/prep/audio_dataset.py | 8 +- src/vak/prep/constants.py | 1 + src/vak/prep/dataset_df_helper.py | 1 + .../assign_samples_to_splits.py | 1 + .../frame_classification.py | 1 + .../prep/frame_classification/learncurve.py | 1 + .../prep/frame_classification/make_splits.py | 9 +- .../prep/frame_classification/validators.py | 1 + .../prep/parametric_umap/dataset_arrays.py | 1 + src/vak/prep/sequence_dataset.py | 1 + src/vak/prep/spectrogram_dataset/__init__.py | 1 + src/vak/prep/spectrogram_dataset/spect.py | 7 +- .../prep/spectrogram_dataset/spect_helper.py | 9 +- src/vak/prep/split/split.py | 1 + src/vak/prep/unit_dataset/unit_dataset.py | 1 + src/vak/train/frame_classification.py | 91 +++--- src/vak/train/parametric_umap.py | 54 +--- src/vak/train/train_.py | 53 +--- .../defaults/frame_classification.py | 52 ++-- src/vak/transforms/defaults/get.py | 20 +- .../transforms/defaults/parametric_umap.py | 9 +- src/vak/transforms/frame_labels/functional.py | 1 + src/vak/transforms/frame_labels/transforms.py | 1 + ...oderUMAP_eval_audio_cbin_annot_notmat.toml | 11 +- ...derUMAP_train_audio_cbin_annot_notmat.toml | 11 +- ...weetyNet_eval_audio_cbin_annot_notmat.toml | 17 +- ...et_learncurve_audio_cbin_annot_notmat.toml | 20 +- ...tyNet_predict_audio_cbin_annot_notmat.toml | 15 +- ...eetyNet_train_audio_cbin_annot_notmat.toml | 18 +- 
...rain_continue_audio_cbin_annot_notmat.toml | 18 +- ...train_continue_spect_mat_annot_yarden.toml | 18 +- ...weetyNet_train_spect_mat_annot_yarden.toml | 18 +- ...on_config.toml => invalid_key_config.toml} | 18 +- ..._config.toml => invalid_table_config.toml} | 10 +- .../invalid_train_and_learncurve_config.toml | 14 +- tests/fixtures/annot.py | 7 +- tests/fixtures/config.py | 280 +++++++++--------- tests/fixtures/csv.py | 6 +- tests/fixtures/dataset.py | 2 +- tests/scripts/vaktestdata/configs.py | 61 ++-- tests/scripts/vaktestdata/source_files.py | 26 +- tests/test_cli/test_eval.py | 22 +- tests/test_cli/test_learncurve.py | 30 +- tests/test_cli/test_predict.py | 20 +- tests/test_cli/test_prep.py | 26 +- tests/test_cli/test_train.py | 20 +- tests/test_config/__init__.py | 2 +- tests/test_config/test_config.py | 144 +++++++-- tests/test_config/test_dataset.py | 133 +++++++++ tests/test_config/test_eval.py | 231 ++++++++++++++- tests/test_config/test_learncurve.py | 204 ++++++++++++- tests/test_config/test_load.py | 24 ++ tests/test_config/test_model.py | 215 ++++++++++---- tests/test_config/test_parse.py | 243 --------------- tests/test_config/test_predict.py | 185 +++++++++++- tests/test_config/test_prep.py | 177 ++++++++++- tests/test_config/test_spect_params.py | 25 +- tests/test_config/test_train.py | 161 +++++++++- tests/test_config/test_validators.py | 25 +- .../test_frames_dataset.py | 9 +- .../test_window_dataset.py | 11 +- .../test_parametric_umap.py | 4 +- tests/test_eval/test_eval.py | 18 +- tests/test_eval/test_frame_classification.py | 70 ++--- tests/test_eval/test_parametric_umap.py | 67 ++--- .../test_frame_classification.py | 40 +-- tests/test_metrics/test_segmentation.py | 0 tests/test_models/test_base.py | 16 +- .../test_frame_classification_model.py | 4 +- .../test_models/test_parametric_umap_model.py | 4 +- .../test_predict/test_frame_classification.py | 82 +++-- tests/test_predict/test_predict.py | 20 +- .../test_assign_samples_to_splits.py | 2 +- .../test_frame_classification.py | 80 ++--- .../test_get_or_make_source_files.py | 6 +- .../test_learncurve.py | 24 +- .../test_make_splits.py | 2 +- tests/test_prep/test_prep.py | 12 +- tests/test_train/test_frame_classification.py | 88 ++---- tests/test_train/test_parametric_umap.py | 63 ++-- tests/test_train/test_train.py | 22 +- 150 files changed, 3455 insertions(+), 2383 deletions(-) create mode 100644 src/vak/config/dataset.py create mode 100644 src/vak/config/load.py delete mode 100644 src/vak/config/parse.py rename src/vak/config/{valid.toml => valid-version-1.0.toml} (68%) rename tests/data_for_tests/configs/{invalid_option_config.toml => invalid_key_config.toml} (61%) rename tests/data_for_tests/configs/{invalid_section_config.toml => invalid_table_config.toml} (72%) create mode 100644 tests/test_config/test_dataset.py create mode 100644 tests/test_config/test_load.py delete mode 100644 tests/test_config/test_parse.py create mode 100644 tests/test_metrics/test_segmentation.py diff --git a/doc/get_started/autoannotate.md b/doc/get_started/autoannotate.md index eab532e89..37ba9eef9 100644 --- a/doc/get_started/autoannotate.md +++ b/doc/get_started/autoannotate.md @@ -20,10 +20,10 @@ Below is an example of some annotated Bengalese finch song, which is what we'll :::{hint} `vak` has built-in support for widely-used annotation formats. 
-Even if your data is not annotated with one of these formats, -you can use `vak` by converting your annotations to a simple `.csv` format +Even if your data is not annotated with one of these formats, +you can use `vak` by converting your annotations to a simple `.csv` format that is easy to create with Python libraries like `pandas`. -For more information, please see: +For more information, please see: {ref}`howto-user-annot` ::: @@ -42,39 +42,39 @@ Before going through this tutorial, you'll need to: or [notepad++](https://notepad-plus-plus.org/) 3. Download example data from this dataset: - - one day of birdsong, for training data (click to download) + - one day of birdsong, for training data (click to download) {download}`https://figshare.com/ndownloader/files/41668980` - another day, to use to predict annotations (click to download) {download}`https://figshare.com/ndownloader/files/41668983` - - Be sure to extract the files from these archives! - Please use the program "tar" to extract the archives, + - Be sure to extract the files from these archives! + Please use the program "tar" to extract the archives, on either macOS/Linux or Windows. - Using other programs like WinZIP on Windows + Using other programs like WinZIP on Windows can corrupt the files when extracting them, causing confusing errors. Tar should be available on newer Windows systems - (as described + (as described [here](https://learn.microsoft.com/en-us/virtualization/community/team-blog/2017/20171219-tar-and-curl-come-to-windows)). - - Alternatively you can copy the following command and then - paste it into a terminal to run a Python script - that will download and extract the files for you. + - Alternatively you can copy the following command and then + paste it into a terminal to run a Python script + that will download and extract the files for you. :::{eval-rst} - + .. tabs:: - + .. code-tab:: shell macOS / Linux - + curl -sSL https://raw.githubusercontent.com/vocalpy/vak/main/src/scripts/download_autoannotate_data.py | python3 - - + .. code-tab:: shell Windows - + (Invoke-WebRequest -Uri https://raw.githubusercontent.com/vocalpy/vak/main/src/scripts/download_autoannotate_data.py -UseBasicParsing).Content | py - ::: 4. Download the corresponding configuration files (click to download): {download}`gy6or6_train.toml <../toml/gy6or6_train.toml>`, - {download}`gy6or6_eval.toml <../toml/gy6or6_eval.toml>`, + {download}`gy6or6_eval.toml <../toml/gy6or6_eval.toml>`, and {download}`gy6or6_predict.toml <../toml/gy6or6_predict.toml>` ## Overview @@ -181,7 +181,7 @@ Change the part of the path in capital letters to the actual location on your computer: ```toml -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" # we change the next line @@ -230,11 +230,11 @@ When you run `prep`, `vak` converts the data from `data_dir` into a special data automatically adds the path to that file to the `[TRAIN]` section of the `config.toml` file, as the option `csv_path`. -You have now prepared a dataset for training a model! -You'll probably have more questions about -how to do this later, -when you start to work with your own data. -When that time comes, please see the how-to page: +You have now prepared a dataset for training a model! +You'll probably have more questions about +how to do this later, +when you start to work with your own data. +When that time comes, please see the how-to page: {ref}`howto-prep-annotate`. For now, let's move on to training a neural network with this dataset. 
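To make the new 1.0 format concrete, here is a minimal before-and-after sketch of the config table that `vak prep` updates (this sketch is not part of the patch itself; the `path` value is a hypothetical placeholder that `prep` fills in for you):

```toml
# before this patch: flat, uppercase section names, with a `dataset_path` option
[TRAIN]
model = "TweetyNet"
dataset_path = "/PATH/TO/PREPPED/DATASET"  # added by `vak prep`

[TRAIN.train_dataset_params]
window_size = 176

# after this patch: tables nested under a top-level `vak` table,
# with a `dataset` sub-table that `vak prep` writes the `path` key into
[vak.train]

[vak.train.dataset]
path = "/PATH/TO/PREPPED/DATASET"  # added by `vak prep`

[vak.train.dataset.params]
window_size = 176

# the model is now declared by a table name instead of a `model` key
[vak.train.model.TweetyNet]
```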
@@ -294,7 +294,7 @@ from that checkpoint later when we predict annotations for new data. (prepare-prediction-dataset)= -An important step when using neural network models is to evaluate the model's performance +An important step when using neural network models is to evaluate the model's performance on a held-out dataset that has never been used during training, often called the "test" set. Here we show you how to evaluate the model we just trained. @@ -356,33 +356,33 @@ This file will also be found in the root `results_{timestamp}` directory. spect_scaler = "/home/users/You/Data/vak_tutorial_data/vak_output/results_{timestamp}/SpectScaler" ``` -The last path you need is actually in the TOML file that we used +The last path you need is actually in the TOML file that we used to train the neural network: `dataset_path`. -You should copy that `dataset_path` option exactly as it is -and then paste it at the bottom of the `[EVAL]` table +You should copy that `dataset_path` option exactly as it is +and then paste it at the bottom of the `[EVAL]` table in the configuration file for evaluation. -We do this instead of preparing another dataset, -because we already created a test split when we ran +We do this instead of preparing another dataset, +because we already created a test split when we ran `vak prep` with the training configuration. -This is a good practice, because it helps ensure +This is a good practice, because it helps ensure that we do not mix the training data with the test data; -`vak` makes sure that the data from the `data_dir` option +`vak` makes sure that the data from the `data_dir` option is placed in two separate splits, the train and test splits. -Once you have prepared the configuration file as described, +Once you have prepared the configuration file as described, you can run the following in the terminal: ```shell vak eval gy6o6_eval.toml ``` -You will see output to the console as the network is evaluated. -Notice that for this model we evaluate it *with* and *without* -post-processing transforms that clean up the predictions +You will see output to the console as the network is evaluated. +Notice that for this model we evaluate it *with* and *without* +post-processing transforms that clean up the predictions of the model. -The parameters of the post-processing transform are specified +The parameters of the post-processing transform are specified with the `post_tfm_kwargs` option in the configuration file. -You may find this helpful to understand factors affecting +You may find this helpful to understand factors affecting the performance of your own model. ## 4. Preparing a prediction dataset @@ -400,7 +400,7 @@ Just like before, you're going to modify the `data_dir` option of the This time you'll change it to the path to the directory with the other day of data we downloaded. ```toml -[PREP] +[vak.prep] data_dir = "/home/users/You/Data/vak_tutorial_data/032312" ``` @@ -428,7 +428,7 @@ and then add the path to that file as the option `csv_path` in the `[PREDICT]` s Finally you will use the trained network to predict annotations. This is the part that requires you to find paths to files saved by `vak`. -There's three you need. These are the exact same paths we used above +There's three you need. These are the exact same paths we used above in the configuration file for evaluation, so you can copy them from that file. We explain them again here for completeness. 
All three paths will be in the `results` directory diff --git a/doc/reference/config.md b/doc/reference/config.md index 687f45e2a..dbe0ec9ba 100644 --- a/doc/reference/config.md +++ b/doc/reference/config.md @@ -19,7 +19,7 @@ for each class. ## Valid section names Following is the set of valid section names: -`{PREP, SPECT_PARAMS, DATALOADER, TRAIN, PREDICT, LEARNCURVE}`. +`{eval, learncurve, predict, prep, train}`. In the code, these names correspond to attributes of the main `Config` class, as shown below. @@ -43,50 +43,42 @@ that are considered valid. Valid options for each section are presented below. (ref-config-prep)= -### `[PREP]` section +### `[vak.prep]` section ```{eval-rst} .. autoclass:: vak.config.prep.PrepConfig ``` (ref-config-spect-params)= -### `[SPECT_PARAMS]` section +### `[vak.prep.spect_params]` section ```{eval-rst} .. autoclass:: vak.config.spect_params.SpectParamsConfig ``` -(ref-config-dataloader)= -### `[DATALOADER]` section - -```{eval-rst} -.. autoclass:: vak.config.dataloader.DataLoaderConfig - -``` - (ref-config-train)= -### `[TRAIN]` section +### `[vak.train]` section ```{eval-rst} .. autoclass:: vak.config.train.TrainConfig ``` (ref-config-eval)= -### `[EVAL]` section +### `[vak.eval]` section ```{eval-rst} .. autoclass:: vak.config.eval.EvalConfig ``` (ref-config-predict)= -### `[PREDICT]` section +### `[vak.predict]` section ```{eval-rst} .. autoclass:: vak.config.predict.PredictConfig ``` (ref-config-learncurve)= -### `[LEARNCURVE]` section +### `[vak.learncurve]` section ```{eval-rst} .. autoclass:: vak.config.learncurve.LearncurveConfig diff --git a/doc/toml/gy6or6_eval.toml b/doc/toml/gy6or6_eval.toml index fcd9a7203..71355ed28 100644 --- a/doc/toml/gy6or6_eval.toml +++ b/doc/toml/gy6or6_eval.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" # input_type: input to model, either audio ("audio") or spectrogram ("spect") @@ -19,7 +19,7 @@ train_dur = 50 val_dur = 15 # SPECT_PARAMS: parameters for computing spectrograms -[SPECT_PARAMS] +[vak.prep.spect_params] # fft_size: size of window used for Fast Fourier Transform, in number of samples fft_size = 512 # step_size: size of step to take when computing spectra with FFT for spectrogram @@ -27,8 +27,7 @@ fft_size = 512 step_size = 64 # EVAL: options for evaluating a trained model. This is done using the "test" split. -[EVAL] -model = "TweetyNet" +[vak.eval] # checkpoint_path: path to saved model checkpoint checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" # labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; @@ -51,7 +50,7 @@ output_dir = "/PATH/TO/FOLDER/results/eval" # ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) # EVAL.post_tfm_kwargs: options for post-processing -[EVAL.post_tfm_kwargs] +[vak.eval.post_tfm_kwargs] # both these transforms require that there is an "unlabeled" label, # and they will only be applied to segments that are bordered on both sides # by the "unlabeled" label. @@ -65,12 +64,11 @@ majority_vote = true # Only applied if this option is specified. 
min_segment_dur = 0.02 -# transform_params: parameters used when transforming data -# for a frame classification model, we use FrameDataset with the eval_item_transform, -# that reshapes batches into consecutive adjacent windows with a specific `window_size` -[EVAL.transform_params] +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.eval.dataset.params] window_size = 176 -# Note we do not specify any options for the network, and just use the defaults -# We need to put this "dummy" table here though for the config to parse correctly -[TweetyNet] +# Note we do not specify any options for the model, and just use the defaults +# We need to put this table here though so we know which model we are using +[vak.eval.model.TweetyNet] diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml index bcdfbd240..c4c89ef73 100644 --- a/doc/toml/gy6or6_predict.toml +++ b/doc/toml/gy6or6_predict.toml @@ -1,5 +1,5 @@ # PREP: options for preparing dataset -[PREP] +[vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" # input_type: input to model, either audio ("audio") or spectrogram ("spect") @@ -15,7 +15,7 @@ audio_format = "wav" # all data found in `data_dir` will be assigned to a "predict split" instead # SPECT_PARAMS: parameters for computing spectrograms -[SPECT_PARAMS] +[vak.prep.spect_params] # fft_size: size of window used for Fast Fourier Transform, in number of samples fft_size = 512 # step_size: size of step to take when computing spectra with FFT for spectrogram @@ -23,9 +23,7 @@ fft_size = 512 step_size = 64 # PREDICT: options for generating predictions with a trained model -[PREDICT] -# model: the string name of the model. must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` -model = "TweetyNet" +[vak.predict] # checkpoint_path: path to saved model checkpoint checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" # labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; @@ -61,12 +59,11 @@ majority_vote = true min_segment_dur = 0.01 # dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it -# transform_params: parameters used when transforming data -# for a frame classification model, we use FrameDataset with the eval_item_transform, -# that reshapes batches into consecutive adjacent windows with a specific `window_size` -[PREDICT.transform_params] +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.predict.dataset.params] window_size = 176 # Note we do not specify any options for the network, and just use the defaults -# We need to put this "dummy" table here though for the config to parse correctly -[TweetyNet] +# We need to put this table here though, to indicate which model we are using. 
+[vak.predict.model.TweetyNet] diff --git a/doc/toml/gy6or6_train.toml b/doc/toml/gy6or6_train.toml index e86b5f7c8..68202f796 100644 --- a/doc/toml/gy6or6_train.toml +++ b/doc/toml/gy6or6_train.toml @@ -1,5 +1,5 @@ # PREP: options for preparing dataset -[PREP] +[vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" # input_type: input to model, either audio ("audio") or spectrogram ("spect") @@ -22,7 +22,7 @@ val_dur = 15 test_dur = 30 # SPECT_PARAMS: parameters for computing spectrograms -[SPECT_PARAMS] +[vak.prep.spect_params] # fft_size: size of window used for Fast Fourier Transform, in number of samples fft_size = 512 # step_size: size of step to take when computing spectra with FFT for spectrogram @@ -30,9 +30,7 @@ fft_size = 512 step_size = 64 # TRAIN: options for training model -[TRAIN] -# model: the string name of the model. must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` -model = "TweetyNet" +[vak.train] # root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` root_results_dir = "/PATH/TO/FOLDER/results/train" # batch_size: number of samples from dataset per batch fed into network @@ -58,23 +56,20 @@ num_workers = 4 device = "cuda" # dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it -# train_dataset_params: parameters used when loading training dataset -# for a frame classification model, we use a WindowDataset with a specific `window_size` -[TRAIN.train_dataset_params] +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.train.dataset.params] window_size = 176 -# val_transform_params: parameters used when transforming validation data -# for a frame classification model, we use FrameDataset with the eval_item_transform, -# that reshapes batches into consecutive adjacent windows with a specific `window_size` -[TRAIN.val_transform_params] -window_size = 176 - -# TweetyNet.optimizer: we specify options for the model's optimizer in this table -[TweetyNet.optimizer] +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer` +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table # lr: the learning rate lr = 0.001 # TweetyNet.network: we specify options for the model's network in this table -[TweetyNet.network] +[vak.train.model.TweetyNet.network] # hidden_size: the number of elements in the hidden state in the recurrent layer of the network hidden_size = 256 diff --git a/noxfile.py b/noxfile.py index 188fe1281..f931e8803 100644 --- a/noxfile.py +++ b/noxfile.py @@ -61,11 +61,11 @@ def lint(session): """ Run the linter. 
""" - session.install(".[dev]") + session.install("isort", "black", "flake8") # run isort first since black disagrees with it session.run("isort", "./src") session.run("black", "./src", "--line-length=79") - session.run("flake8", "./src", "--max-line-length", "120", "--exclude", "./src/crowsetta/_vendor") + session.run("flake8", "./src", "--max-line-length", "120") @nox.session diff --git a/pyproject.toml b/pyproject.toml index bacf95d7a..f0895a996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "SoundFile >=0.10.3", "pandas >=1.0.1", "tensorboard >=2.8.0", - "toml >=0.10.2", + "tomlkit >=0.12.4", "torch >= 2.0.1", "torchvision >=0.15.2", "tqdm >=4.42.1", diff --git a/src/scripts/download_autoannotate_data.py b/src/scripts/download_autoannotate_data.py index 2d8ea268f..cca3c6e14 100644 --- a/src/scripts/download_autoannotate_data.py +++ b/src/scripts/download_autoannotate_data.py @@ -3,6 +3,7 @@ Adapted from https://github.com/NickleDave/bfsongrepo/blob/main/src/scripts/download_dataset.py """ + from __future__ import annotations import argparse diff --git a/src/vak/__main__.py b/src/vak/__main__.py index c3f6c0bac..a25d3f833 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -2,6 +2,7 @@ Invokes __main__ when the module is run as a script. Example: python -m vak --help """ + import argparse from pathlib import Path diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 329bf38b0..29bee65a5 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -1,5 +1,9 @@ +"""Evaluate a trained model with dataset specified in config.toml file.""" + +from __future__ import annotations + import logging -from pathlib import Path +import pathlib from .. import config from .. import eval as eval_module @@ -8,8 +12,9 @@ logger = logging.getLogger(__name__) -def eval(toml_path): - """evaluate a trained model with dataset specified in config.toml file. +def eval(toml_path: str | pathlib.Path) -> None: + """Evaluate a trained model with dataset specified in config.toml file. + Function called by command-line interface. Parameters @@ -21,8 +26,8 @@ def eval(toml_path): ------- None """ - toml_path = Path(toml_path) - cfg = config.parse.from_toml_path(toml_path) + toml_path = pathlib.Path(toml_path) + cfg = config.Config.from_toml_path(toml_path) if cfg.eval is None: raise ValueError( @@ -37,10 +42,7 @@ def eval(toml_path): logger.info("Logging results to {}".format(cfg.eval.output_dir)) - model_name = cfg.eval.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - - if cfg.eval.dataset_path is None: + if cfg.eval.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." 
f"To generate a .csv file that represents the dataset, " @@ -48,16 +50,13 @@ def eval(toml_path): ) eval_module.eval( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, batch_size=cfg.eval.batch_size, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, post_tfm_kwargs=cfg.eval.post_tfm_kwargs, diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 33c293a61..2decc5cd8 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -23,7 +23,7 @@ def learning_curve(toml_path): path to a configuration file in TOML format. """ toml_path = Path(toml_path) - cfg = config.parse.from_toml_path(toml_path) + cfg = config.Config.from_toml_path(toml_path) if cfg.learncurve is None: raise ValueError( @@ -45,10 +45,7 @@ def learning_curve(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - model_name = cfg.learncurve.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - - if cfg.learncurve.dataset_path is None: + if cfg.learncurve.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." f"To generate a .csv file that represents the dataset, " @@ -56,16 +53,11 @@ def learning_curve(toml_path): ) learncurve.learning_curve( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.learncurve.dataset_path, + model_config=cfg.learncurve.model.asdict(), + dataset_config=cfg.learncurve.dataset.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, - train_transform_params=cfg.learncurve.train_transform_params, - train_dataset_params=cfg.learncurve.train_dataset_params, - val_transform_params=cfg.learncurve.val_transform_params, - val_dataset_params=cfg.learncurve.val_dataset_params, results_path=results_path, post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 38701b87f..01c0e2612 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -18,7 +18,7 @@ def predict(toml_path): path to a configuration file in TOML format. """ toml_path = Path(toml_path) - cfg = config.parse.from_toml_path(toml_path) + cfg = config.Config.from_toml_path(toml_path) if cfg.predict is None: raise ValueError( @@ -35,10 +35,7 @@ def predict(toml_path): log_version(logger) logger.info("Logging results to {}".format(cfg.prep.output_dir)) - model_name = cfg.predict.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - - if cfg.predict.dataset_path is None: + if cfg.predict.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." 
f"To generate a .csv file that represents the dataset, " @@ -46,15 +43,12 @@ def predict(toml_path): ) predict_module.predict( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.predict.dataset_path, + model_config=cfg.predict.model.asdict(), + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index 3a8ee6b8d..d86c4c0a9 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -1,48 +1,58 @@ """Function called by command-line interface for prep command""" + from __future__ import annotations import pathlib import shutil import warnings -import toml +import tomlkit from .. import config from .. import prep as prep_module -from ..config.parse import _load_toml_from_path -from ..config.validators import are_sections_valid +from ..config.load import _load_toml_from_path +from ..config.validators import are_tables_valid def purpose_from_toml( - config_toml: dict, toml_path: str | pathlib.Path | None = None + config_dict: dict, toml_path: str | pathlib.Path | None = None ) -> str: - """determine "purpose" from toml config, + """Determine "purpose" from toml config, i.e., the command that will be run after we ``prep`` the data. - By convention this is the other section in the config file - that correspond to a cli command besides '[PREP]' + By convention this is the other top-level table in the config file + that correspond to a cli command besides ``[vak.prep]``, e.g. ``[vak.train]``. """ # validate, make sure there aren't multiple commands in one config file first - are_sections_valid(config_toml, toml_path=toml_path) + are_tables_valid(config_dict, toml_path=toml_path) + config_dict = config_dict from ..cli.cli import CLI_COMMANDS # avoid circular imports - commands_that_are_not_prep = ( + commands_that_are_not_prep = [ command for command in CLI_COMMANDS if command != "prep" - ) - for command in commands_that_are_not_prep: - section_name = ( - command.upper() - ) # we write section names in uppercase, e.g. `[PREP]`, by convention - if section_name in config_toml: - return section_name.lower() # this is the "purpose" of the file + ] + purpose = None + for table_name in commands_that_are_not_prep: + if table_name in config_dict: + purpose = ( + table_name # this top-level table is the "purpose" of the file + ) + if purpose is None: + raise ValueError( + "Did not find a top-level table in configuration file that corresponds to a CLI command. 
" + f"Configuration file path: {toml_path}\n" + f"Found the following top-level tables: {config_dict.keys()}\n" + f"Valid CLI commands besides ``prep`` (that correspond top-level tables) are: {commands_that_are_not_prep}" + ) + return purpose # note NO LOGGING -- we configure logger inside `core.prep` # so we can save log file inside dataset directory # see https://github.com/NickleDave/vak/issues/334 -SECTIONS_PREP_SHOULD_PARSE = ("PREP", "SPECT_PARAMS", "DATALOADER") +TABLES_PREP_SHOULD_PARSE = "prep" def prep(toml_path): @@ -83,52 +93,52 @@ def prep(toml_path): """ toml_path = pathlib.Path(toml_path) - # open here because need to check for `dataset_path` in this function, see #314 & #333 - config_toml = _load_toml_from_path(toml_path) - # ---- figure out purpose of config file from sections; will save csv path in that section ------------------------- - purpose = purpose_from_toml(config_toml, toml_path) + # open here because need to check whether the `dataset` already has a `path`, see #314 & #333 + config_dict = _load_toml_from_path(toml_path) + + # ---- figure out purpose of config file from tables; will save path of prep'd dataset in that table --------------- + purpose = purpose_from_toml(config_dict, toml_path) if ( - "dataset_path" in config_toml[purpose.upper()] - and config_toml[purpose.upper()]["dataset_path"] is not None + "dataset" in config_dict[purpose] + and "path" in config_dict[purpose]["dataset"] ): raise ValueError( - f"config .toml file already has a 'dataset_path' option in the '{purpose.upper()}' section, " - f"and running `prep` would overwrite that value. To `prep` a new dataset, please remove " - f"the 'dataset_path' option from the '{purpose.upper()}' section in the config file:\n{toml_path}" + f"This configuration file already has a '{purpose}.dataset' table with a 'path' key, " + f"and running `prep` would overwrite the value for that key. To `prep` a new dataset, please " + "either create a new configuration file, or remove " + f"the 'path' key-value pair from the '{purpose}.dataset' table in the file:\n{toml_path}" ) - # now that we've checked that, go ahead and parse the sections we want - cfg = config.parse.from_toml_path( - toml_path, sections=SECTIONS_PREP_SHOULD_PARSE - ) - # notice we ignore any other option/values in the 'purpose' section, + # now that we've checked that, go ahead and parse just the prep tabel; + # we don't load the 'purpose' table into a config, to avoid error messages like non-existent paths, etc. # see https://github.com/NickleDave/vak/issues/334 and https://github.com/NickleDave/vak/issues/314 + cfg = config.Config.from_toml_path( + toml_path, tables_to_parse=TABLES_PREP_SHOULD_PARSE + ) if cfg.prep is None: raise ValueError( - f"prep called with a config.toml file that does not have a PREP section: {toml_path}" + f"prep called with a config.toml file that does not have a [vak.prep] table: {toml_path}" ) if purpose == "predict": if cfg.prep.labelset is not None: warnings.warn( - "config has a PREDICT section, but labelset option is specified in PREP section." - "This would cause an error because the dataframe.from_files section will attempt to " + "config has a [vak.predict] table, but labelset option is specified in [vak.prep] table." + "This would cause an error because the dataframe.from_files method will attempt to " f"check whether the files in the data_dir ({cfg.prep.data_dir}) have labels in " "labelset, even though those files don't have annotation.\n" "Setting labelset to None." 
) cfg.prep.labelset = None - section = purpose.upper() - - dataset_df, dataset_path = prep_module.prep( + _, dataset_path = prep_module.prep( data_dir=cfg.prep.data_dir, purpose=purpose, dataset_type=cfg.prep.dataset_type, input_type=cfg.prep.input_type, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -141,11 +150,15 @@ num_replicates=cfg.prep.num_replicates, ) - # use config and section from above to add dataset_path to config.toml file - config_toml[section]["dataset_path"] = str(dataset_path) - + # we re-open config using tomlkit so we can add path to dataset table in a style-preserving way + with toml_path.open("r") as fp: + tomldoc = tomlkit.load(fp) + if "dataset" not in tomldoc["vak"][purpose]: + dataset_table = tomlkit.table() + tomldoc["vak"][purpose].add("dataset", dataset_table) + tomldoc["vak"][purpose]["dataset"].add("path", str(dataset_path)) with toml_path.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(tomldoc, fp) # lastly, copy config to dataset directory root shutil.copy(src=toml_path, dst=dataset_path) diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 91c89bb95..c63096ca2 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -23,7 +23,7 @@ def train(toml_path): path to a configuration file in TOML format. """ toml_path = Path(toml_path) - cfg = config.parse.from_toml_path(toml_path) + cfg = config.Config.from_toml_path(toml_path) if cfg.train is None: raise ValueError( @@ -45,10 +45,7 @@ def train(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - model_name = cfg.train.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - - if cfg.train.dataset_path is None: + if cfg.train.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." f"To generate a .csv file that represents the dataset, " @@ -56,13 +53,8 @@ def train(toml_path): ) train_module.train( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.train.dataset_path, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, + model_config=cfg.train.model.asdict(), + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py index e453adbb6..777bd9afb 100644 --- a/src/vak/common/__init__.py +++ b/src/vak/common/__init__.py @@ -7,6 +7,7 @@ See for example :mod:`vak.prep.prep_helper` or :mod:`vak.datsets.window_dataset._helper`. """ + from . import ( annotation, constants, diff --git a/src/vak/common/constants.py b/src/vak/common/constants.py index fcf2ab7d6..a3a34315f 100644 --- a/src/vak/common/constants.py +++ b/src/vak/common/constants.py @@ -1,6 +1,7 @@ """constants used by multiple modules. Defined here to avoid circular imports.
""" + from functools import partial import crowsetta diff --git a/src/vak/common/converters.py b/src/vak/common/converters.py index 6b349e182..8d5735649 100644 --- a/src/vak/common/converters.py +++ b/src/vak/common/converters.py @@ -52,10 +52,12 @@ def range_str(range_str, sort=True): subrange, substr ) ) - list_range.extend([int(subrange[0])]) if len( - subrange - ) == 1 else list_range.extend( - range(int(subrange[0]), int(subrange[1]) + 1) + ( + list_range.extend([int(subrange[0])]) + if len(subrange) == 1 + else list_range.extend( + range(int(subrange[0]), int(subrange[1]) + 1) + ) ) if sort: diff --git a/src/vak/common/labels.py b/src/vak/common/labels.py index dd515df40..5f851cdec 100644 --- a/src/vak/common/labels.py +++ b/src/vak/common/labels.py @@ -172,8 +172,7 @@ def multi_char_labels_to_single_char( # which would map it to a new integer and cause us to lose the original integer # from the mapping single_char_labels_not_in_labelmap = [ - lbl for lbl in DUMMY_SINGLE_CHAR_LABELS - if lbl not in labelmap + lbl for lbl in DUMMY_SINGLE_CHAR_LABELS if lbl not in labelmap ] n_needed_to_remap = len( [lbl for lbl in current_str_labels if len(lbl) > 1] @@ -187,7 +186,9 @@ def multi_char_labels_to_single_char( new_labelmap = {} for dummy_label_ind, label_str in enumerate(current_str_labels): label_int = labelmap[label_str] - if len(label_str) > 1 and label_str not in skip: # default for `skip` is ('unlabeled',) + if ( + len(label_str) > 1 and label_str not in skip + ): # default for `skip` is ('unlabeled',) # replace with dummy label new_label_str = single_char_labels_not_in_labelmap[dummy_label_ind] new_labelmap[new_label_str] = label_int diff --git a/src/vak/common/logging.py b/src/vak/common/logging.py index 8ced29688..fa65f272d 100644 --- a/src/vak/common/logging.py +++ b/src/vak/common/logging.py @@ -1,4 +1,5 @@ """utility functions for logging""" + import logging import sys import warnings diff --git a/src/vak/common/paths.py b/src/vak/common/paths.py index 212ad32fc..12648f393 100644 --- a/src/vak/common/paths.py +++ b/src/vak/common/paths.py @@ -1,4 +1,5 @@ """functions for working with paths""" + from pathlib import Path from . import constants, timenow diff --git a/src/vak/common/tensorboard.py b/src/vak/common/tensorboard.py index 6e6b50d88..43db0e53e 100644 --- a/src/vak/common/tensorboard.py +++ b/src/vak/common/tensorboard.py @@ -1,4 +1,5 @@ """Functions dealing with ``tensorboard``""" + from __future__ import annotations from pathlib import Path diff --git a/src/vak/common/timebins.py b/src/vak/common/timebins.py index dd1d8375a..afef34e55 100644 --- a/src/vak/common/timebins.py +++ b/src/vak/common/timebins.py @@ -1,5 +1,6 @@ """module for functions that deal with vector of times from a spectrogram, i.e. where elements are the times at bin centers""" + import numpy as np diff --git a/src/vak/common/validators.py b/src/vak/common/validators.py index ecc7f1d5d..b51399bc1 100644 --- a/src/vak/common/validators.py +++ b/src/vak/common/validators.py @@ -1,4 +1,5 @@ """Functions for input validation""" + import pathlib import warnings diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 8f20f2224..056c0ef12 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -1,26 +1,47 @@ """sub-package that parses config.toml files and returns config object""" + from . 
import ( config, + dataset, eval, learncurve, + load, model, - parse, predict, prep, spect_params, train, validators, ) +from .config import Config +from .dataset import DatasetConfig +from .eval import EvalConfig +from .learncurve import LearncurveConfig +from .model import ModelConfig +from .predict import PredictConfig +from .prep import PrepConfig +from .spect_params import SpectParamsConfig +from .train import TrainConfig __all__ = [ "config", + "dataset", "eval", "learncurve", "model", - "parse", + "load", "predict", "prep", "spect_params", "train", "validators", + "Config", + "DatasetConfig", + "EvalConfig", + "LearncurveConfig", + "ModelConfig", + "PredictConfig", + "PrepConfig", + "SpectParamsConfig", + "TrainConfig", ] diff --git a/src/vak/config/config.py b/src/vak/config/config.py index 377802b3b..553afb464 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -1,43 +1,183 @@ -import attr +"""Class that represents the TOML configuration file used with the vak command-line interface.""" + +from __future__ import annotations + +import pathlib + from attr.validators import instance_of, optional +from attrs import define, field +from . import load from .eval import EvalConfig from .learncurve import LearncurveConfig from .predict import PredictConfig from .prep import PrepConfig -from .spect_params import SpectParamsConfig from .train import TrainConfig +from .validators import are_keys_valid, are_tables_valid + +TABLE_CLASSES_MAP = { + "eval": EvalConfig, + "learncurve": LearncurveConfig, + "predict": PredictConfig, + "prep": PrepConfig, + "train": TrainConfig, +} + +def _validate_tables_to_parse_arg_convert_list( + tables_to_parse: str | list[str], +) -> list[str]: + """Helper function used by :meth:`Config.from_config_dict` that + validates the ``tables_to_parse`` argument, + and returns it as a list of strings.""" + if isinstance(tables_to_parse, str): + tables_to_parse = [tables_to_parse] -@attr.s + if not isinstance(tables_to_parse, list): + raise TypeError( + f"`tables_to_parse` should be a string or list of strings but type was: {type(tables_to_parse)}" + ) + + if not all( + [isinstance(table_name, str) for table_name in tables_to_parse] + ): + raise ValueError( + "All table names in 'tables_to_parse' should be strings" + ) + if not all( + [ + table_name in list(TABLE_CLASSES_MAP.keys()) + for table_name in tables_to_parse + ] + ): + raise ValueError( + "All table names in 'tables_to_parse' should be valid names of tables. " + f"Values for 'tables_to_parse' were: {tables_to_parse}.\n" + f"Valid table names are: {list(TABLE_CLASSES_MAP.keys())}" + ) + return tables_to_parse + +
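+# A sketch of the behavior this helper guarantees for ``tables_to_parse``
+# (doctest-style; note that "spect_params" is no longer a top-level table
+# in version 1.0 of the config format, so it is rejected here):
+#
+# >>> _validate_tables_to_parse_arg_convert_list("train")
+# ['train']
+# >>> _validate_tables_to_parse_arg_convert_list(["prep", "train"])
+# ['prep', 'train']
+# >>> _validate_tables_to_parse_arg_convert_list("spect_params")
+# Traceback (most recent call last):
+#     ...
+# ValueError: All table names in 'tables_to_parse' should be valid names of tables. ...
+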
+@define class Config: - """class to represent config.toml file + """Class that represents the TOML configuration file used with the vak command-line interface. Attributes ---------- prep : vak.config.prep.PrepConfig - represents ``[PREP]`` section of config.toml file - spect_params : vak.config.spect_params.SpectParamsConfig - represents ``[SPECT_PARAMS]`` section of config.toml file + Represents ``[vak.prep]`` table of config.toml file train : vak.config.train.TrainConfig - represents ``[TRAIN]`` section of config.toml file + Represents ``[vak.train]`` table of config.toml file eval : vak.config.eval.EvalConfig - represents ``[EVAL]`` section of config.toml file + Represents ``[vak.eval]`` table of config.toml file predict : vak.config.predict.PredictConfig - represents ``[PREDICT]`` section of config.toml file. + Represents ``[vak.predict]`` table of config.toml file. learncurve : vak.config.learncurve.LearncurveConfig - represents ``[LEARNCURVE]`` section of config.toml file + Represents ``[vak.learncurve]`` table of config.toml file """ - spect_params = attr.ib( - validator=instance_of(SpectParamsConfig), default=SpectParamsConfig() - ) - prep = attr.ib(validator=optional(instance_of(PrepConfig)), default=None) - train = attr.ib(validator=optional(instance_of(TrainConfig)), default=None) - eval = attr.ib(validator=optional(instance_of(EvalConfig)), default=None) - predict = attr.ib( + prep = field(validator=optional(instance_of(PrepConfig)), default=None) + train = field(validator=optional(instance_of(TrainConfig)), default=None) + eval = field(validator=optional(instance_of(EvalConfig)), default=None) + predict = field( validator=optional(instance_of(PredictConfig)), default=None ) - learncurve = attr.ib( + learncurve = field( validator=optional(instance_of(LearncurveConfig)), default=None ) + + @classmethod + def from_config_dict( + cls, + config_dict: dict, + tables_to_parse: str | list[str] | None = None, + toml_path: str | pathlib.Path | None = None, + ) -> "Config": + """Return instance of :class:`Config` class, + given a :class:`dict` containing the contents of + a TOML configuration file. + + This :func:`classmethod` expects the output + of :func:`vak.config.load._load_toml_from_path`, + that converts a :class:`tomlkit.TOMLDocument` + to a :class:`dict`, and returns the :class:`dict` + that is accessed by the top-level key ``'vak'``. + + Parameters + ---------- + config_dict : dict + Python ``dict`` containing a .toml configuration file, + loaded with ``tomlkit`` and converted to plain Python objects. + tables_to_parse : str, list + Name of top-level table or tables from configuration + file that should be parsed. Can be a string + (single table) or list of strings (multiple + tables). Default is None, + in which case all are validated and parsed. + toml_path : str, pathlib.Path + path to a configuration file in TOML format. Default is None. + Not required, used only to make any error messages clearer. + + Returns + ------- + config : vak.config.Config + instance of :class:`Config` class, + whose attributes correspond to the + top-level tables in a config.toml file. + """ + are_tables_valid(config_dict, toml_path) + if tables_to_parse is None: + tables_to_parse = list( + config_dict.keys() + ) # i.e., parse all top-level tables + else: + tables_to_parse = _validate_tables_to_parse_arg_convert_list( + tables_to_parse + ) + + config_kwargs = {} + for table_name in tables_to_parse: + if table_name in config_dict: + are_keys_valid(config_dict, table_name, toml_path) + table_config_dict = config_dict[table_name] + config_kwargs[table_name] = TABLE_CLASSES_MAP[ + table_name + ].from_config_dict(table_config_dict) + else: + raise KeyError( + f"A table specified in `tables_to_parse` was not found in the config: {table_name}" + ) + + return cls(**config_kwargs) + + @classmethod + def from_toml_path( + cls, + toml_path: str | pathlib.Path, + tables_to_parse: list[str] | None = None, + ) -> "Config": + """Return instance of :class:`Config` class, + given the path to a TOML configuration file. + + Parameters + ---------- + toml_path : str, pathlib.Path + Path to a configuration file in TOML format. + Parsed with ``tomlkit``, then converted to an + instance of :class:`vak.config.Config` by + calling :meth:`Config.from_config_dict`. + tables_to_parse : str, list + Name of table or tables from configuration + file that should be parsed.
Can be a string + (single table) or list of strings (multiple + tables). Default is None, + in which case all are validated and parsed. + + Returns + ------- + config : vak.config.Config + instance of :class:`Config` class, whose attributes correspond to + tables in a config.toml file. + """ + config_dict: dict = load._load_toml_from_path(toml_path) + return cls.from_config_dict(config_dict, tables_to_parse, toml_path) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py new file mode 100644 index 000000000..f75c34b73 --- /dev/null +++ b/src/vak/config/dataset.py @@ -0,0 +1,59 @@ +"""Class that represents dataset table in configuration file.""" + +from __future__ import annotations + +import pathlib + +import attr.validators +from attr import asdict, define, field + + +@define +class DatasetConfig: + """Class that represents dataset table in configuration file. + + Attributes + ---------- + path : pathlib.Path + Path to the directory that contains the dataset. + Equivalent to the `root` parameter of :mod:`torchvision` + datasets. + splits_path : pathlib.Path, optional + Path to file representing splits. + Default is None. + name : str, optional + Name of dataset. Only required for built-in datasets + from the :mod:`~vak.datasets` module. Default is None. + params: dict, optional + Parameters for dataset class, + passed in as keyword arguments. + E.g., ``window_size=2000``. + Default is an empty :class:`dict`. + """ + + path: pathlib.Path = field(converter=pathlib.Path) + splits_path: pathlib.Path | None = field( + converter=attr.converters.optional(pathlib.Path), default=None + ) + name: str | None = field( + converter=attr.converters.optional(str), default=None + ) + params: dict | None = field( + # we default to an empty dict instead of None + # so we can still do **['dataset']['params'] everywhere we do when params are specified + converter=attr.converters.optional(dict), + default={}, + ) + + @classmethod + def from_config_dict(cls, config_dict: dict) -> DatasetConfig: + """Return :class:`DatasetConfig` instance from a :class:`dict`.""" + return cls(**config_dict) + + def asdict(self): + """Convert this :class:`DatasetConfig` instance + to a :class:`dict` that can be passed + into functions that take a ``dataset_config`` argument, + like :func:`vak.train` and :func:`vak.predict`. + """ + return asdict(self)
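For orientation, a minimal sketch of the mapping this class defines (the values here are invented; ``path`` is the only required key):

    from vak.config import DatasetConfig

    dataset_config = DatasetConfig.from_config_dict(
        {"path": "~/some/dataset/dir", "params": {"window_size": 2000}}
    )
    # optional keys default to None; ``params`` defaults to an empty dict
    assert dataset_config.name is None and dataset_config.splits_path is None
    # round-trip to a plain dict, to pass as a ``dataset_config`` argument
    assert dataset_config.asdict()["params"] == {"window_size": 2000}

diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index 6991b89a7..3012da649 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -1,11 +1,16 @@ -"""parses [EVAL] section of config""" -import attr -from attr import converters, validators -from attr.validators import instance_of +"""Class and functions for ``[vak.eval]`` table in configuration file.""" + +from __future__ import annotations + +import pathlib + +from attrs import converters, define, field, validators +from attrs.validators import instance_of from ..common import device from ..common.converters import expanded_user_path -from .validators import is_valid_model_name +from .dataset import DatasetConfig +from .model import ModelConfig def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict: @@ -67,22 +72,35 @@ def are_valid_post_tfm_kwargs(instance, attribute, value): ) -@attr.s +REQUIRED_KEYS = ( + "checkpoint_path", + "dataset", + "output_dir", + "model", +) + + +@define class EvalConfig: - """class that represents [EVAL] section of config.toml file + """Class that represents ``[vak.eval]`` table in configuration file. Attributes ---------- - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``.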
+ dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model output_dir : str Path to location where .csv files with evaluation metrics should be saved. labelmap_path : str path to 'labelmap.json' file. - model : str - Model name, e.g., ``model = "TweetyNet"`` + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` batch_size : int number of samples per batch presented to models during training. num_workers : int @@ -108,63 +126,64 @@ class EvalConfig: a float value for ``min_segment_dur``. See the docstring of the transform for more details on these arguments and how they work. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. """ # required, external files - checkpoint_path = attr.ib(converter=expanded_user_path) - output_dir = attr.ib(converter=expanded_user_path) + checkpoint_path: pathlib.Path = field(converter=expanded_user_path) + output_dir: pathlib.Path = field(converter=expanded_user_path) # required, model / dataloader - model = attr.ib( - validator=[instance_of(str), is_valid_model_name], + model = field( + validator=instance_of(ModelConfig), ) - batch_size = attr.ib(converter=int, validator=instance_of(int)) - - # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at - # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = attr.ib( - converter=converters.optional(expanded_user_path), - default=None, + batch_size = field(converter=int, validator=instance_of(int)) + dataset: DatasetConfig = field( + validator=instance_of(DatasetConfig), ) # "optional" but actually required for frame classification models # TODO: check model family in __post_init__ and raise ValueError if labelmap # TODO: not specified for a frame classification model? - labelmap_path = attr.ib( + labelmap_path = field( converter=converters.optional(expanded_user_path), default=None ) # optional, transform - spect_scaler_path = attr.ib( + spect_scaler_path = field( converter=converters.optional(expanded_user_path), default=None, ) - post_tfm_kwargs = attr.ib( + post_tfm_kwargs = field( validator=validators.optional(are_valid_post_tfm_kwargs), converter=converters.optional(convert_post_tfm_kwargs), default=None, ) # optional, data loader - num_workers = attr.ib(validator=instance_of(int), default=2) - device = attr.ib(validator=instance_of(str), default=device.get_default()) - - transform_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - dataset_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) + num_workers = field(validator=instance_of(int), default=2) + device = field(validator=instance_of(str), default=device.get_default()) + + @classmethod + def from_config_dict(cls, config_dict: dict) -> EvalConfig: + """Return :class:`EvalConfig` instance from a :class:`dict`. 
+ + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :meth:`vak.config.Config.from_toml_path`, + and then using key ``eval``, + i.e., ``EvalConfig.from_config_dict(config_dict['eval'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.eval]` table in a configuration file requires " + f"the key '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary. " + "Please check that the configuration file is formatted correctly." + ) + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] + ) + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] + ) + return cls(**config_dict) diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py index 13fc6021a..fdf2b883a 100644 --- a/src/vak/config/learncurve.py +++ b/src/vak/config/learncurve.py @@ -1,21 +1,32 @@ -"""parses [LEARNCURVE] section of config""" -import attr -from attr import converters, validators +"""Class that represents ``[vak.learncurve]`` table in configuration file.""" +from __future__ import annotations + +from attrs import converters, define, field, validators + +from .dataset import DatasetConfig from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs +from .model import ModelConfig from .train import TrainConfig +REQUIRED_KEYS = ("dataset", "model", "root_results_dir") + -@attr.s +@define class LearncurveConfig(TrainConfig): - """class that represents [LEARNCURVE] section of config.toml file + """Class that represents ``[vak.learncurve]`` table in configuration file. Attributes ---------- - model : str - Model name, e.g., ``model = "TweetyNet"`` - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. num_epochs : int number of training epochs. One epoch = one iteration through the entire training set. @@ -51,8 +62,35 @@ class LearncurveConfig(TrainConfig): these arguments and how they work. """ - post_tfm_kwargs = attr.ib( + post_tfm_kwargs = field( validator=validators.optional(are_valid_post_tfm_kwargs), converter=converters.optional(convert_post_tfm_kwargs), default=None, ) + + # we override this method from TrainConfig mainly so the docstring is correct. + # TODO: can we do this by just overwriting `__doc__` for the method on this class? + @classmethod + def from_config_dict(cls, config_dict: dict) -> "LearncurveConfig": + """Return :class:`LearncurveConfig` instance from a :class:`dict`. + + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :meth:`vak.config.Config.from_toml_path`, + and then using key ``learncurve``, + i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.learncurve]` table in a configuration file requires " + f"the key '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary.
" + "Please check that the configuration file is formatted correctly." + ) + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] + ) + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] + ) + return cls(**config_dict) diff --git a/src/vak/config/load.py b/src/vak/config/load.py new file mode 100644 index 000000000..3134dc85e --- /dev/null +++ b/src/vak/config/load.py @@ -0,0 +1,94 @@ +"""Functions to parse toml config files.""" + +from __future__ import annotations + +import pathlib + +import tomlkit +import tomlkit.exceptions + + +def _tomlkit_to_popo(d): + """Convert tomlkit to "popo" (Plain-Old Python Objects) + + From https://github.com/python-poetry/tomlkit/issues/43#issuecomment-660415820 + + We need this so we don't get a ``tomlkit.items._ConvertError`` when + the `from_config_dict` classmethods try to add a class to a ``config_dict``, + e.g. when :meth:`EvalConfig.from_config_dict` converts the ``spect_params`` + key-value pairs to a :class:`vak.config.SpectParamsConfig` instance + and then assigns it to the ``spect_params`` key. + We would get this error if we just return the result of :func:`tomlkit.load`, + which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. + """ + try: + result = getattr(d, "value") + except AttributeError: + result = d + + if isinstance(result, list): + result = [_tomlkit_to_popo(x) for x in result] + elif isinstance(result, dict): + result = { + _tomlkit_to_popo(key): _tomlkit_to_popo(val) + for key, val in result.items() + } + elif isinstance(result, tomlkit.items.Integer): + result = int(result) + elif isinstance(result, tomlkit.items.Float): + result = float(result) + elif isinstance(result, tomlkit.items.String): + result = str(result) + elif isinstance(result, tomlkit.items.Bool): + result = bool(result) + + return result + + +def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: + """Load a toml file from a path, and return as a :class:`dict`. + + Notes + ----- + Helper function to load toml config file, + factored out to use in other modules when needed. + Checks if ``toml_path`` exists before opening, + and tries to give a clear message if an error occurs when loading. + + Note also this function checks that the loaded :class:`dict` + has a single top-level key ``'vak'``, + and that it returns the :class:`dict` one level down + that is accessed with that key. + This avoids the need to write ``['vak']`` everywhere in + calling functions. + However it also means you need to add back that key + if you are *writing* a toml file. + """ + toml_path = pathlib.Path(toml_path) + if not toml_path.is_file(): + raise FileNotFoundError(f".toml config file not found: {toml_path}") + + try: + with toml_path.open("r") as fp: + config_dict: dict = tomlkit.load(fp) + except tomlkit.exceptions.TOMLKitError as e: + raise Exception( + f"Error when parsing .toml config file: {toml_path}" + ) from e + + if "vak" not in config_dict: + raise ValueError( + "Toml file does not contain a top-level table named `vak`. " + "Please see example configuration files here:\n" + "https://github.com/vocalpy/vak/tree/main/doc/toml" + ) + + # Next line, convert TOMLDocument returned by tomlkit.load to a dict. + # We need this so we don't get a ``tomlkit.items._ConvertError`` when + # the `from_config_dict` classmethods try to add a class to a ``config_dict``, + # e.g. 
when :meth:`EvalConfig.from_config_dict` converts the ``spect_params`` + # key-value pairs to a :class:`vak.config.SpectParamsConfig` instance + # and then assigns it to the ``spect_params`` key. + # We would get this error if we just return the result of :func:`tomlkit.load`, + # which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. + return _tomlkit_to_popo(config_dict)["vak"] diff --git a/src/vak/config/model.py b/src/vak/config/model.py index d1d36643b..aaf920866 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -1,8 +1,9 @@ -from __future__ import annotations +"""Class representing the model table of a toml configuration file.""" -import pathlib +from __future__ import annotations -import toml +from attrs import asdict, define, field +from attrs.validators import instance_of from .. import models @@ -14,77 +15,88 @@ ] -def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict: - """Get configuration for a model from a .toml configuration file - loaded into a ``dict``. - - Parameters - ---------- - toml_dict : dict - Configuration from a .toml file, loaded into a dictionary. - model_name : str - Name of a model, specified as the ``model`` option in a table - (such as TRAIN or PREDICT), - that should have its own corresponding table - specifying its configuration: hyperparameters such as learning rate, etc. - - Returns - ------- - model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - """ - if model_name not in models.registry.MODEL_NAMES: - raise ValueError( - f"Invalid model name: {model_name}.\nValid model names are: {models.registry.MODEL_NAMES}" - ) - - try: - model_config = toml_dict[model_name] - except KeyError as e: - raise ValueError( - f"A config section specifies the model name '{model_name}', " - f"but there is no section named '{model_name}' in the config." - ) from e - - # check if config declares parameters for required attributes; - # if not, just put an empty dict that will get passed as the "kwargs" - for attr in MODEL_TABLES: - if attr not in model_config: - model_config[attr] = {} - - return model_config +@define +class ModelConfig: + """Class representing the model table of a toml configuration file. - -def config_from_toml_path( - toml_path: str | pathlib.Path, model_name: str -) -> dict: - """Get configuration for a model from a .toml configuration file, - given the path to the file. - - Parameters + Attributes ---------- - toml_path : str, Path - to configuration file in .toml format - model_name : str - of str, i.e. names of models specified by a section - (such as TRAIN or PREDICT) that should each have corresponding sections - specifying their configuration: hyperparameters such as learning rate, etc. - - Returns - ------- - model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. + name : str + network : dict + Keyword arguments for the network class, + or a :class:`dict` of ``dict``s mapping + network names to keyword arguments. + optimizer: dict + Keyword arguments for the optimizer class. + loss : dict + Keyword arguments for the class representing the loss function. + metrics: dict + A :class:`dict` of ``dict``s mapping + metric names to keyword arguments. 
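+
+    Examples
+    --------
+    A sketch of how a model table maps onto this class (assuming the
+    built-in ``TweetyNet`` model; the ``lr`` value here is invented):
+
+    >>> config_dict = {"TweetyNet": {"optimizer": {"lr": 1e-3}}}
+    >>> ModelConfig.from_config_dict(config_dict)
+    ModelConfig(name='TweetyNet', network={}, optimizer={'lr': 0.001}, loss={}, metrics={})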
""" - toml_path = pathlib.Path(toml_path) - if not toml_path.is_file(): - raise FileNotFoundError( - f"File not found, or not recognized as a file: {toml_path}" - ) - with toml_path.open("r") as fp: - config_dict = toml.load(fp) - return config_from_toml_dict(config_dict, model_name) + name: str + network: dict = field(validator=instance_of(dict)) + optimizer: dict = field(validator=instance_of(dict)) + loss: dict = field(validator=instance_of(dict)) + metrics: dict = field(validator=instance_of(dict)) + + @classmethod + def from_config_dict(cls, config_dict: dict): + """Return :class:`ModelConfig` instance from a :class:`dict`. + + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :func:`vak.config.parse.from_toml_path`, + and then using a top-level table key, + followed by key ``'model'``. + E.g., ``config_dict['train']['model']` or + ``config_dict['predict']['model']``. + + Examples + -------- + config_dict = vak.config.parse.from_toml_path(toml_path) + model_config = vak.config.Model.from_config_dict(config_dict['train']) + """ + model_name = list(config_dict.keys()) + if len(model_name) == 0: + raise ValueError( + "Did not find a single key in `config_dict` corresponding to model name. " + f"Instead found no keys. Config dict:\n{config_dict}\n" + "A configuration file should specify a single model per top-level table." + ) + if len(model_name) > 1: + raise ValueError( + "Did not find a single key in `config_dict` corresponding to model name. " + f"Instead found multiple keys: {model_name}.\nConfig dict:\n{config_dict}.\n" + "A configuration file should specify a single model per top-level table." + ) + model_name = model_name[0] + MODEL_NAMES = list(models.registry.MODEL_NAMES) + if model_name not in MODEL_NAMES: + raise ValueError( + f"Model name not found in registry: {model_name}\n" + f"Model names in registry:\n{MODEL_NAMES}" + ) + model_config = config_dict[model_name] + if not all(key in MODEL_TABLES for key in model_config.keys()): + invalid_keys = ( + key for key in model_config.keys() if key not in MODEL_TABLES + ) + raise ValueError( + f"The following sub-tables in the model config are not valid: {invalid_keys}\n" + f"Valid sub-table names are: {MODEL_TABLES}" + ) + # for any tables not specified, default to empty dict so we can still use ``**`` operator on it + for model_table in MODEL_TABLES: + if model_table not in config_dict: + model_config[model_table] = {} + return cls(name=model_name, **model_config) + + def asdict(self): + """Convert this :class:`ModelConfig` instance + to a :class:`dict` that can be passed + into functions that take a ``model_config`` argument, + like :func:`vak.train` and :func:`vak.predict`. 
+ """ + return asdict(self) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py deleted file mode 100644 index 295a1a537..000000000 --- a/src/vak/config/parse.py +++ /dev/null @@ -1,202 +0,0 @@ -from pathlib import Path - -import toml -from toml.decoder import TomlDecodeError - -from .config import Config -from .eval import EvalConfig -from .learncurve import LearncurveConfig -from .predict import PredictConfig -from .prep import PrepConfig -from .spect_params import SpectParamsConfig -from .train import TrainConfig -from .validators import are_options_valid, are_sections_valid - -SECTION_CLASSES = { - "EVAL": EvalConfig, - "LEARNCURVE": LearncurveConfig, - "PREDICT": PredictConfig, - "PREP": PrepConfig, - "SPECT_PARAMS": SpectParamsConfig, - "TRAIN": TrainConfig, -} - -REQUIRED_OPTIONS = { - "EVAL": [ - "checkpoint_path", - "output_dir", - "model", - ], - "LEARNCURVE": [ - "model", - "root_results_dir", - ], - "PREDICT": [ - "checkpoint_path", - "model", - ], - "PREP": [ - "data_dir", - "output_dir", - ], - "SPECT_PARAMS": None, - "TRAIN": [ - "model", - "root_results_dir", - ], -} - - -def parse_config_section(config_toml, section_name, toml_path=None): - """parse section of config.toml file - - Parameters - ---------- - config_toml : dict - containing config.toml file already loaded by parse function - section_name : str - name of section from configuration - file that should be parsed - toml_path : str - path to a configuration file in TOML format. Default is None. - Used for error messages if specified. - - Returns - ------- - config : vak.config section class - instance of class that represents section of config.toml file, - e.g. PredictConfig for 'PREDICT' section - """ - section = dict(config_toml[section_name].items()) - - required_options = REQUIRED_OPTIONS[section_name] - if required_options is not None: - for required_option in required_options: - if required_option not in section: - if toml_path: - err_msg = ( - f"the '{required_option}' option is required but was not found in the " - f"{section_name} section of the config.toml file: {toml_path}" - ) - else: - err_msg = ( - f"the '{required_option}' option is required but was not found in the " - f"{section_name} section of the toml config" - ) - raise KeyError(err_msg) - return SECTION_CLASSES[section_name](**section) - - -def _validate_sections_arg_convert_list(sections): - if isinstance(sections, str): - sections = [sections] - elif isinstance(sections, list): - if not all( - [isinstance(section_name, str) for section_name in sections] - ): - raise ValueError( - "all section names in 'sections' should be strings" - ) - if not all( - [ - section_name in list(SECTION_CLASSES.keys()) - for section_name in sections - ] - ): - raise ValueError( - "all section names in 'sections' should be valid names of sections. " - f"Values for 'sections were: {sections}.\n" - f"Valid section names are: {list(SECTION_CLASSES.keys())}" - ) - return sections - - -def from_toml(config_toml, toml_path=None, sections=None): - """load a TOML configuration file - - Parameters - ---------- - config_toml : dict - Python ``dict`` containing a .toml configuration file, - parsed by the ``toml`` library. - toml_path : str, Path - path to a configuration file in TOML format. Default is None. - Not required, used only to make any error messages clearer. - sections : str, list - name of section or sections from configuration - file that should be parsed. Can be a string - (single section) or list of strings (multiple - sections). 
Default is None, - in which case all are validated and parsed. - - Returns - ------- - config : vak.config.parse.Config - instance of Config class, whose attributes correspond to - sections in a config.toml file. - """ - are_sections_valid(config_toml, toml_path) - - sections = _validate_sections_arg_convert_list(sections) - - config_dict = {} - if sections is None: - sections = list( - SECTION_CLASSES.keys() - ) # i.e., parse all sections, except model - for section_name in sections: - if section_name in config_toml: - are_options_valid(config_toml, section_name, toml_path) - config_dict[section_name.lower()] = parse_config_section( - config_toml, section_name, toml_path - ) - - return Config(**config_dict) - - -def _load_toml_from_path(toml_path): - """helper function to load toml config file, - factored out to use in other modules when needed - - checks if ``toml_path`` exists before opening, - and tries to give a clear message if an error occurs when parsing""" - toml_path = Path(toml_path) - if not toml_path.is_file(): - raise FileNotFoundError(f".toml config file not found: {toml_path}") - - try: - with toml_path.open("r") as fp: - config_toml = toml.load(fp) - except TomlDecodeError as e: - raise Exception( - f"Error when parsing .toml config file: {toml_path}" - ) from e - - return config_toml - - -def from_toml_path(toml_path, sections=None): - """parse a TOML configuration file - - Parameters - ---------- - toml_path : str, Path - path to a configuration file in TOML format. - Parsed by ``toml`` library, then converted to an - instance of ``vak.config.parse.Config`` by - calling ``vak.parse.from_toml`` - sections : str, list - name of section or sections from configuration - file that should be parsed. Can be a string - (single section) or list of strings (multiple - sections). Default is None, - in which case all are validated and parsed. - - Returns - ------- - config : vak.config.parse.Config - instance of Config class, whose attributes correspond to - sections in a config.toml file. - """ - config_toml = _load_toml_from_path(toml_path) - return from_toml(config_toml, toml_path, sections) diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index 852605165..8803d9317 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -1,30 +1,45 @@ -"""parses [PREDICT] section of config""" +"""Class that represents ``[vak.predict]`` table of configuration file.""" + +from __future__ import annotations + import os from pathlib import Path -import attr from attr import converters, validators from attr.validators import instance_of +from attrs import define, field from ..common import device from ..common.converters import expanded_user_path -from .validators import is_valid_model_name +from .dataset import DatasetConfig +from .model import ModelConfig + +REQUIRED_KEYS = ( + "checkpoint_path", + "dataset", + "model", +) -@attr.s +@define class PredictConfig: - """class that represents [PREDICT] section of config.toml file + """Class that represents ``[vak.predict]`` table of configuration file. Attributes ---------- - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. - checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model - labelmap_path : str - path to 'labelmap.json' file. 
- model : str - Model name, e.g., ``model = "TweetyNet"`` + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. + checkpoint_path : str + path to directory with checkpoint files saved by Torch, to reload model + labelmap_path : str + path to 'labelmap.json' file. + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` batch_size : int number of samples per batch presented to models during training. num_workers : int @@ -68,64 +83,65 @@ class PredictConfig: spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, and the network is `TweetyNet`, then the net output file will be `gy6or6_032312_081416.tweetynet.output.npz`. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. """ # required, external files - checkpoint_path = attr.ib(converter=expanded_user_path) - labelmap_path = attr.ib(converter=expanded_user_path) + checkpoint_path = field(converter=expanded_user_path) + labelmap_path = field(converter=expanded_user_path) # required, model / dataloader - model = attr.ib( - validator=[instance_of(str), is_valid_model_name], + model = field( + validator=instance_of(ModelConfig), ) - batch_size = attr.ib(converter=int, validator=instance_of(int)) - - # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at - # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = attr.ib( - converter=converters.optional(expanded_user_path), - default=None, + batch_size = field(converter=int, validator=instance_of(int)) + dataset: DatasetConfig = field( + validator=instance_of(DatasetConfig), ) # optional, transform - spect_scaler_path = attr.ib( + spect_scaler_path = field( converter=converters.optional(expanded_user_path), default=None, ) # optional, data loader - num_workers = attr.ib(validator=instance_of(int), default=2) - device = attr.ib(validator=instance_of(str), default=device.get_default()) + num_workers = field(validator=instance_of(int), default=2) + device = field(validator=instance_of(str), default=device.get_default()) - annot_csv_filename = attr.ib( + annot_csv_filename = field( validator=validators.optional(instance_of(str)), default=None ) - output_dir = attr.ib( + output_dir = field( converter=expanded_user_path, default=Path(os.getcwd()), ) - min_segment_dur = attr.ib( + min_segment_dur = field( validator=validators.optional(instance_of(float)), default=None ) - majority_vote = attr.ib(validator=instance_of(bool), default=True) - save_net_outputs = attr.ib(validator=instance_of(bool), default=False) + majority_vote = field(validator=instance_of(bool), default=True) + save_net_outputs = field(validator=instance_of(bool), default=False) - transform_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) + @classmethod + def from_config_dict(cls, config_dict: dict) -> PredictConfig: + """Return :class:`PredictConfig` instance from a :class:`dict`. 
- dataset_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :meth:`vak.config.Config.from_toml_path`, + and then using key ``predict``, + i.e., ``PredictConfig.from_config_dict(config_dict['predict'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.predict]` table in a configuration file requires " + f"the key '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary. " + "Please check that the configuration file is formatted correctly." + ) + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] + ) + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] + ) + return cls(**config_dict) diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py index 7481d8cc2..3023ce204 100644 --- a/src/vak/config/prep.py +++ b/src/vak/config/prep.py @@ -1,13 +1,16 @@ -"""parses [PREP] section of config""" + +"""Class and functions for ``[vak.prep]`` table of configuration file.""" + +from __future__ import annotations + import inspect -import attr import dask.bag -from attr import converters, validators -from attr.validators import instance_of +from attrs import converters, define, field, validators +from attrs.validators import instance_of from .. import prep from ..common.converters import expanded_user_path, labelset_to_set +from .spect_params import SpectParamsConfig from .validators import is_annot_format, is_audio_format, is_spect_format @@ -60,9 +63,15 @@ def are_valid_dask_bag_kwargs(instance, attribute, value): ) -@attr.s +REQUIRED_KEYS = ( + "data_dir", + "output_dir", +) + + +@define class PrepConfig: - """class to represent [PREP] section of config.toml file + """Class that represents ``[vak.prep]`` table of configuration file. Attributes ---------- @@ -84,6 +93,11 @@ class PrepConfig: spect_format : str format of files containing spectrograms as 2-d matrices. One of {'mat', 'npy'}. + spect_params: vak.config.SpectParamsConfig, optional + Parameters for Short-Time Fourier Transform and post-processing + of spectrograms. + Instance of :class:`vak.config.SpectParamsConfig` class. + Optional, default is None. annot_format : str format of annotations. Any format that can be used with the crowsetta library is valid. @@ -127,10 +141,10 @@ class PrepConfig: Default is None. Required if config file has a learncurve section. """ - data_dir = attr.ib(converter=expanded_user_path) - output_dir = attr.ib(converter=expanded_user_path) + data_dir = field(converter=expanded_user_path) + output_dir = field(converter=expanded_user_path) - dataset_type = attr.ib(validator=instance_of(str)) + dataset_type = field(validator=instance_of(str)) @dataset_type.validator def is_valid_dataset_type(self, attribute, value): @@ -140,7 +154,7 @@ def is_valid_dataset_type(self, attribute, value): f"Valid dataset types are: {prep.constants.DATASET_TYPES}" ) - input_type = attr.ib(validator=instance_of(str)) + input_type = field(validator=instance_of(str)) @input_type.validator def is_valid_input_type(self, attribute, value): @@ -149,49 +163,53 @@ def is_valid_input_type(self, attribute, value): f"Invalid input type: {value}. 
Must be one of: {prep.constants.INPUT_TYPES}" ) - audio_format = attr.ib( + audio_format = field( validator=validators.optional(is_audio_format), default=None ) - spect_format = attr.ib( + spect_format = field( validator=validators.optional(is_spect_format), default=None ) - annot_file = attr.ib( + spect_params = field( + validator=validators.optional(instance_of(SpectParamsConfig)), + default=None, + ) + annot_file = field( converter=converters.optional(expanded_user_path), default=None, ) - annot_format = attr.ib( + annot_format = field( validator=validators.optional(is_annot_format), default=None ) - labelset = attr.ib( + labelset = field( converter=converters.optional(labelset_to_set), validator=validators.optional(instance_of(set)), default=None, ) - audio_dask_bag_kwargs = attr.ib( + audio_dask_bag_kwargs = field( validator=validators.optional(are_valid_dask_bag_kwargs), default=None ) - train_dur = attr.ib( + train_dur = field( converter=converters.optional(duration_from_toml_value), validator=validators.optional(is_valid_duration), default=None, ) - val_dur = attr.ib( + val_dur = field( converter=converters.optional(duration_from_toml_value), validator=validators.optional(is_valid_duration), default=None, ) - test_dur = attr.ib( + test_dur = field( converter=converters.optional(duration_from_toml_value), validator=validators.optional(is_valid_duration), default=None, ) - train_set_durs = attr.ib( + train_set_durs = field( validator=validators.optional(instance_of(list)), default=None ) - num_replicates = attr.ib( + num_replicates = field( validator=validators.optional(instance_of(int)), default=None ) @@ -203,3 +221,26 @@ def __attrs_post_init__(self): raise ValueError( "must specify either audio_format or spect_format" ) + + @classmethod + def from_config_dict(cls, config_dict: dict) -> PrepConfig: + """Return :class:`PrepConfig` instance from a :class:`dict`. + + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :meth:`vak.config.Config.from_toml_path`, + and then using key ``prep``, + i.e., ``PrepConfig.from_config_dict(config_dict['prep'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.prep]` table in a configuration file requires " + f"the key '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary. " + "Please check that the configuration file is formatted correctly."
+ ) + if "spect_params" in config_dict: + config_dict["spect_params"] = SpectParamsConfig( + **config_dict["spect_params"] + ) + return cls(**config_dict) diff --git a/src/vak/config/spect_params.py b/src/vak/config/spect_params.py index 4a61942a6..b570f9e7c 100644 --- a/src/vak/config/spect_params.py +++ b/src/vak/config/spect_params.py @@ -1,4 +1,5 @@ """parses [SPECT_PARAMS] section of config""" + import attr from attr import converters, validators from attr.validators import instance_of diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 034a110b4..4c997a8c1 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -1,23 +1,31 @@ -"""parses [TRAIN] section of config""" -import attr -from attr import converters, validators -from attr.validators import instance_of +"""Class that represents ``[vak.train]`` table of configuration file.""" + +from attrs import converters, define, field, validators +from attrs.validators import instance_of from ..common import device from ..common.converters import bool_from_str, expanded_user_path -from .validators import is_valid_model_name +from .dataset import DatasetConfig +from .model import ModelConfig + +REQUIRED_KEYS = ("dataset", "model", "root_results_dir") -@attr.s +@define class TrainConfig: - """class that represents [TRAIN] section of config.toml file + """Class that represents ``[vak.train]`` table of configuration file. Attributes ---------- - model : str - Model name, e.g., ``model = "TweetyNet"`` - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. num_epochs : int number of training epochs. One epoch = one iteration through the entire training set. 
@@ -64,82 +72,78 @@ class TrainConfig: """ # required - model = attr.ib( - validator=[instance_of(str), is_valid_model_name], + model = field( + validator=instance_of(ModelConfig), ) - num_epochs = attr.ib(converter=int, validator=instance_of(int)) - batch_size = attr.ib(converter=int, validator=instance_of(int)) - root_results_dir = attr.ib(converter=expanded_user_path) - - # optional - # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at - # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = attr.ib( - converter=converters.optional(expanded_user_path), - default=None, + num_epochs = field(converter=int, validator=instance_of(int)) + batch_size = field(converter=int, validator=instance_of(int)) + root_results_dir = field(converter=expanded_user_path) + dataset: DatasetConfig = field( + validator=instance_of(DatasetConfig), ) - results_dirname = attr.ib( + results_dirname = field( converter=converters.optional(expanded_user_path), default=None, ) - normalize_spectrograms = attr.ib( + normalize_spectrograms = field( converter=bool_from_str, validator=validators.optional(instance_of(bool)), default=False, ) - num_workers = attr.ib(validator=instance_of(int), default=2) - device = attr.ib(validator=instance_of(str), default=device.get_default()) - shuffle = attr.ib( + num_workers = field(validator=instance_of(int), default=2) + device = field(validator=instance_of(str), default=device.get_default()) + shuffle = field( converter=bool_from_str, validator=instance_of(bool), default=True ) - val_step = attr.ib( + val_step = field( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) - ckpt_step = attr.ib( + ckpt_step = field( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) - patience = attr.ib( + patience = field( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) - checkpoint_path = attr.ib( + checkpoint_path = field( converter=converters.optional(expanded_user_path), default=None, ) - spect_scaler_path = attr.ib( + spect_scaler_path = field( converter=converters.optional(expanded_user_path), default=None, ) - train_transform_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) + @classmethod + def from_config_dict(cls, config_dict: dict) -> "TrainConfig": + """Return :class:`TrainConfig` instance from a :class:`dict`. - train_dataset_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - val_transform_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - val_dataset_params = attr.ib( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :meth:`vak.config.Config.from_toml_path`, + and then using key ``train``, + i.e., ``TrainConfig.from_config_dict(config_dict['train'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.train]` table in a configuration file requires " + f"the key '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary.
" + "Please check that the configuration file is formatted correctly." + ) + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] + ) + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] + ) + return cls(**config_dict) diff --git a/src/vak/config/valid.toml b/src/vak/config/valid-version-1.0.toml similarity index 68% rename from src/vak/config/valid.toml rename to src/vak/config/valid-version-1.0.toml index 11cd535f5..ded6aa6ae 100644 --- a/src/vak/config/valid.toml +++ b/src/vak/config/valid-version-1.0.toml @@ -5,7 +5,7 @@ # Options should be in the same order they are defined for the # attrs-based class that represents the config, for easy comparison # when changing that class + this file. -[PREP] +[vak.prep] data_dir = './tests/test_data/cbins/gy6or6/032312' output_dir = './tests/test_data/prep/learncurve' dataset_type = 'frame_classification' @@ -22,7 +22,7 @@ test_dur = 30 train_set_durs = [ 4.5, 6.0 ] num_replicates = 2 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000 ] @@ -33,10 +33,8 @@ freqbins_key = 'f' timebins_key = 't' audio_path_key = 'audio_path' -[TRAIN] -model = 'TweetyNet' +[vak.train] root_results_dir = './tests/test_data/results/train' -dataset_path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' num_workers = 4 device = 'cuda' batch_size = 11 @@ -49,27 +47,33 @@ patience = 4 results_dir_made_by_main_script = '/some/path/to/learncurve/' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' -train_transform_params = {'resize' = 128} -train_dataset_params = {'window_size' = 80} -val_transform_params = {'resize' = 128} -val_dataset_params = {'window_size' = 80} -[EVAL] -dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +[vak.train.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' +params = {window_size = 2000} + +[vak.train.model.TweetyNet] + +[vak.eval] checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' output_dir = './tests/test_data/prep/learncurve' -model = 'TweetyNet' batch_size = 11 num_workers = 4 device = 'cuda' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} -transform_params = {'resize' = 128} -dataset_params = {'window_size' = 80} -[LEARNCURVE] -model = 'TweetyNet' +[vak.eval.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' + +[vak.eval.model.TweetyNet] + +[vak.learncurve] root_results_dir = './tests/test_data/results/learncurve' batch_size = 11 num_epochs = 2 @@ -78,23 +82,24 @@ shuffle = true val_step = 1 ckpt_step = 1 patience = 4 -dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' results_dir_made_by_main_script = '/some/path/to/learncurve/' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} num_workers = 4 device = 'cuda' -train_transform_params = {'resize' = 128} -train_dataset_params = {'window_size' = 80} -val_transform_params = {'resize' = 128} -val_dataset_params = {'window_size' = 80} 
-[PREDICT] -dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +[vak.learncurve.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' +params = {window_size = 2000} + +[vak.learncurve.model.TweetyNet] + +[vak.predict] checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' annot_csv_filename = '032312_prep_191224_225910.annot.csv' output_dir = './tests/test_data/prep/learncurve' -model = 'TweetyNet' batch_size = 11 num_workers = 4 device = 'cuda' @@ -102,5 +107,11 @@ spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' min_segment_dur = 0.004 majority_vote = false save_net_outputs = false -transform_params = {'resize' = 128} -dataset_params = {'window_size' = 80} + +[vak.predict.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' +params = {window_size = 2000} + +[vak.predict.model.TweetyNet] \ No newline at end of file diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 71d757d10..628656a51 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -1,7 +1,8 @@ """validators used by attrs-based classes and by vak.parse.parse_config""" -from pathlib import Path -import toml +import pathlib + +import tomlkit from .. import models from ..common import constants @@ -9,7 +10,7 @@ def is_a_directory(instance, attribute, value): """check if given path is a directory""" - if not Path(value).is_dir(): + if not pathlib.Path(value).is_dir(): raise NotADirectoryError( f"Value specified for {attribute.name} of {type(instance)} not recognized as a directory:\n" f"{value}" @@ -18,7 +19,7 @@ def is_a_directory(instance, attribute, value): def is_a_file(instance, attribute, value): """check if given path is a file""" - if not Path(value).is_file(): + if not pathlib.Path(value).is_file(): raise FileNotFoundError( f"Value specified for {attribute.name} of {type(instance)} not recognized as a file:\n" f"{value}" @@ -34,13 +35,13 @@ def is_valid_model_name(instance, attribute, value: str) -> None: def is_audio_format(instance, attribute, value): - """check if valid audio format""" + """Check if valid audio format""" if value not in constants.VALID_AUDIO_FORMATS: raise ValueError(f"{value} is not a valid format for audio files") def is_annot_format(instance, attribute, value): - """check if valid annotation format""" + """Check if valid annotation format""" if value not in constants.VALID_ANNOT_FORMATS: raise ValueError( f"{value} is not a valid format for annotation files.\n" @@ -49,7 +50,7 @@ def is_annot_format(instance, attribute, value): def is_spect_format(instance, attribute, value): - """check if valid format for spectrograms""" + """Check if valid format for spectrograms""" if value not in constants.VALID_SPECT_FORMATS: raise ValueError( f"{value} is not a valid format for spectrogram files.\n" @@ -57,73 +58,117 @@ def is_spect_format(instance, attribute, value): ) -CONFIG_DIR = Path(__file__).parent -VALID_TOML_PATH = CONFIG_DIR.joinpath("valid.toml") +CONFIG_DIR = pathlib.Path(__file__).parent +VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.0.toml") with VALID_TOML_PATH.open("r") as fp: - VALID_DICT = 
toml.load(fp) -VALID_SECTIONS = list(VALID_DICT.keys()) -VALID_OPTIONS = { - section: list(options.keys()) for section, options in VALID_DICT.items() + VALID_DICT = tomlkit.load(fp)["vak"] +VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys()) +VALID_KEYS = { + table_name: list(table_config_dict.keys()) + for table_name, table_config_dict in VALID_DICT.items() } -def are_sections_valid(config_dict, toml_path=None): - sections = list(config_dict.keys()) +def are_tables_valid(config_dict, toml_path=None): + """Validate top-level tables in a :class:`dict`. + + This function expects the ``config_dict`` + returned by :func:`vak.config.load._load_toml_from_path`. + """ + tables = list(config_dict.keys()) from ..cli.cli import CLI_COMMANDS # avoid circular import cli_commands_besides_prep = [ command for command in CLI_COMMANDS if command != "prep" ] - sections_that_are_commands_besides_prep = [ - section - for section in sections - if section.lower() in cli_commands_besides_prep + tables_that_are_commands_besides_prep = [ + table for table in tables if table in cli_commands_besides_prep ] - if len(sections_that_are_commands_besides_prep) == 0: + if len(tables_that_are_commands_besides_prep) == 0: raise ValueError( - "did not find a section related to a vak command in config besides `prep`.\n" - f"Sections in config were: {sections}" + "Did not find a table related to a vak command in config besides `prep`.\n" + f"Tables in config were: {tables}\n" + "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" ) - if len(sections_that_are_commands_besides_prep) > 1: + if len(tables_that_are_commands_besides_prep) > 1: raise ValueError( - "found multiple sections related to a vak command in config besides `prep`.\n" - f"Those sections are: {sections_that_are_commands_besides_prep}. " - f"Please use just one command besides `prep` per .toml configuration file" + "Found multiple tables related to a vak command in config besides `prep`.\n" + f"Those tables are: {tables_that_are_commands_besides_prep}. 
" + f"Please use just one command besides `prep` per .toml configuration file.\n" + "See example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" ) - MODEL_NAMES = list(models.registry.MODEL_NAMES) - # add model names to valid sections so users can define model config in sections - valid_sections = VALID_SECTIONS + MODEL_NAMES - for section in sections: - if ( - section not in valid_sections - and f"{section}Model" not in valid_sections - ): + for table in tables: + if table not in VALID_TOP_LEVEL_TABLES: if toml_path: err_msg = ( - f"section defined in {toml_path} is not valid: {section}" + f"Top-level table defined in {toml_path} is not valid: {table}\n" + f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n" + "Please see example toml configuration files here: " + "https://github.com/vocalpy/vak/tree/main/doc/toml" ) else: err_msg = ( - f"section defined in toml config is not valid: {section}" + f"Table defined in toml config is not valid: {table}\n" + f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n" + "Please see example toml configuration files here: " + "https://github.com/vocalpy/vak/tree/main/doc/toml" ) raise ValueError(err_msg) -def are_options_valid(config_dict, section, toml_path=None): - user_options = set(config_dict[section].keys()) - valid_options = set(VALID_OPTIONS[section]) - if not user_options.issubset(valid_options): - invalid_options = user_options - valid_options +def are_keys_valid( + config_dict: dict, + table_name: str, + toml_path: str | pathlib.Path | None = None, +) -> None: + """Given a :class:`dict` containing the *entire* configuration loaded from a toml file, + validate the key names for a specific top-level table, e.g. ``vak.train`` or ``vak.predict`` + """ + table_keys = set(config_dict[table_name].keys()) + valid_keys = set(VALID_KEYS[table_name]) + if not table_keys.issubset(valid_keys): + invalid_keys = table_keys - valid_keys + if toml_path: + err_msg = ( + f"The following keys from '{table_name}' table in " + f"the config file '{toml_path.name}' are not valid:\n{invalid_keys}" + ) + else: + err_msg = ( + f"The following keys from '{table_name}' table in " + f"the toml config are not valid:\n{invalid_keys}" + ) + raise ValueError(err_msg) + + +def are_table_keys_valid( + table_config_dict: dict, + table_name: str, + toml_path: str | pathlib.Path | None = None, +) -> None: + """Given a :class:`dict` containing the configuration for a *specific* top-level table, + loaded from a toml file, validate the key names for that table, + e.g. ``vak.train`` or ``vak.predict``. + + This function assumes ``table_config_dict`` comes from the entire ``config_dict`` + returned by :func:`vak.config.parse.from_toml_path`, accessed using the table name as a key, + unlike :func:`are_keys_valid`. This function is used by the ``from_config_dict`` + classmethod of the top-level tables. 
+    """
+    table_keys = set(table_config_dict.keys())
+    valid_keys = set(VALID_KEYS[table_name])
+    if not table_keys.issubset(valid_keys):
+        invalid_keys = table_keys - valid_keys
         if toml_path:
             err_msg = (
-                f"the following options from {section} section in "
-                f"the config file '{toml_path.name}' are not valid:\n{invalid_options}"
+                f"The following keys from '{table_name}' table in "
+                f"the config file '{toml_path.name}' are not valid:\n{invalid_keys}"
             )
         else:
             err_msg = (
-                f"the following options from {section} section in "
-                f"the toml config are not valid:\n{invalid_options}"
+                f"The following keys from '{table_name}' table in "
+                f"the toml config are not valid:\n{invalid_keys}"
             )
         raise ValueError(err_msg)
diff --git a/src/vak/datasets/frame_classification/frames_dataset.py b/src/vak/datasets/frame_classification/frames_dataset.py
index 6d91ad77e..94ea8169d 100644
--- a/src/vak/datasets/frame_classification/frames_dataset.py
+++ b/src/vak/datasets/frame_classification/frames_dataset.py
@@ -1,6 +1,7 @@
 """A dataset class used for neural network models
 with the frame classification task,
 where the source data consists of audio signals or spectrograms of varying lengths."""
+
 from __future__ import annotations

 import pathlib
@@ -78,8 +79,8 @@ def __init__(
         sample_ids: npt.NDArray,
         inds_in_sample: npt.NDArray,
         frame_dur: float,
+        item_transform: Callable,
         subset: str | None = None,
-        item_transform: Callable | None = None,
     ):
         """Initialize a new instance of a FramesDataset.

@@ -114,9 +115,9 @@
             If specified, this takes precedence over split.
             Subsets are typically taken from the training data
             for use when generating a learning curve.
-        item_transform : callable, optional
-            Transform applied to each item :math:`(x, y)`
-            returned by :meth:`FramesDataset.__getitem__`.
+        item_transform : callable
+            The transform applied to each item :math:`(x, y)`
+            that is returned by :meth:`FramesDataset.__getitem__`.
         """
         from ... import (
             prep,
@@ -195,9 +196,9 @@ def __len__(self):
     def from_dataset_path(
         cls,
         dataset_path: str | pathlib.Path,
+        item_transform: Callable,
         split: str = "val",
         subset: str | None = None,
-        item_transform: Callable | None = None,
     ):
         """Make a :class:`FramesDataset` instance,
         given the path to a frame classification dataset.

@@ -209,17 +210,18 @@
             frame classification dataset,
             as created by
             :func:`vak.prep.prep_frame_classification_dataset`.
+        item_transform : callable
+            Transform applied to each item :math:`(x, y)`
+            returned by :meth:`FramesDataset.__getitem__`.
         split : str
             The name of a split from the dataset,
             one of {'train', 'val', 'test'}.
+            Default is "val".
         subset : str, optional
             Name of subset to use.
             If specified, this takes precedence over split.
             Subsets are typically taken from the training data
             for use when generating a learning curve.
-        item_transform : callable, optional
-            Transform applied to each item :math:`(x, y)`
-            returned by :meth:`FramesDataset.__getitem__`.

         Returns
         -------
@@ -262,6 +264,6 @@
             sample_ids,
             inds_in_sample,
             frame_dur,
-            subset,
             item_transform,
+            subset,
         )
diff --git a/src/vak/datasets/frame_classification/helper.py b/src/vak/datasets/frame_classification/helper.py
index 41163cf79..d6a6f19b1 100644
--- a/src/vak/datasets/frame_classification/helper.py
+++ b/src/vak/datasets/frame_classification/helper.py
@@ -1,4 +1,5 @@
 """Helper functions used with frame classification datasets."""
+
 from __future__ import annotations

 from ... 
import common diff --git a/src/vak/datasets/frame_classification/metadata.py b/src/vak/datasets/frame_classification/metadata.py index 61c7cb918..b7a532aae 100644 --- a/src/vak/datasets/frame_classification/metadata.py +++ b/src/vak/datasets/frame_classification/metadata.py @@ -2,6 +2,7 @@ associated with a frame classification dataset, as generated by :func:`vak.core.prep.frame_classification.prep_frame_classification_dataset`""" + from __future__ import annotations import json diff --git a/src/vak/datasets/frame_classification/window_dataset.py b/src/vak/datasets/frame_classification/window_dataset.py index 30fe034dd..d916a6bcc 100644 --- a/src/vak/datasets/frame_classification/window_dataset.py +++ b/src/vak/datasets/frame_classification/window_dataset.py @@ -15,6 +15,7 @@ :math:`I` determined by a ``stride`` parameter :math:`s`, :math:`I = (T - w) / s`. """ + from __future__ import annotations import pathlib @@ -173,11 +174,10 @@ def __init__( inds_in_sample: npt.NDArray, window_size: int, frame_dur: float, + item_transform: Callable, stride: int = 1, subset: str | None = None, window_inds: npt.NDArray | None = None, - transform: Callable | None = None, - target_transform: Callable | None = None, ): """Initialize a new instance of a WindowDataset. @@ -210,6 +210,9 @@ def __init__( frame_dur: float Duration of a frame, i.e., a single sample in audio or a single timebin in a spectrogram. + item_transform : callable + The transform applied to each item :math:`(x, y)` + that is returned by :meth:`WindowDataset.__getitem__`. stride : int The size of the stride used to determine which windows are included in the dataset. The default is 1. @@ -266,8 +269,7 @@ def __init__( sample_ids.shape[-1], window_size, stride ) self.window_inds = window_inds - self.transform = transform - self.target_transform = target_transform + self.item_transform = item_transform @property def duration(self): @@ -276,10 +278,10 @@ def duration(self): @property def shape(self): tmp_x_ind = 0 - one_x, _ = self.__getitem__(tmp_x_ind) + tmp_item = self.__getitem__(tmp_x_ind) # used by vak functions that need to determine size of window, # e.g. when initializing a neural network model - return one_x.shape + return tmp_item["frames"].shape def _load_frames(self, frames_path): """Helper function that loads "frames", @@ -337,12 +339,8 @@ def __getitem__(self, idx): frame_labels = frame_labels[ inds_in_sample : inds_in_sample + self.window_size # noqa: E203 ] - if self.transform: - frames = self.transform(frames) - if self.target_transform: - frame_labels = self.target_transform(frame_labels) - - return frames, frame_labels + item = self.item_transform(frames, frame_labels) + return item def __len__(self): """number of batches""" @@ -353,11 +351,10 @@ def from_dataset_path( cls, dataset_path: str | pathlib.Path, window_size: int, + item_transform: Callable, stride: int = 1, split: str = "train", subset: str | None = None, - transform: Callable | None = None, - target_transform: Callable | None = None, ): """Make a :class:`WindowDataset` instance, given the path to a frame classification dataset. 
@@ -440,9 +437,8 @@ def from_dataset_path( inds_in_sample, window_size, frame_dur, + item_transform, stride, subset, window_inds, - transform, - target_transform, ) diff --git a/src/vak/datasets/parametric_umap/metadata.py b/src/vak/datasets/parametric_umap/metadata.py index ac0b8a137..a821a223a 100644 --- a/src/vak/datasets/parametric_umap/metadata.py +++ b/src/vak/datasets/parametric_umap/metadata.py @@ -2,6 +2,7 @@ associated with a dimensionality reduction dataset, as generated by :func:`vak.core.prep.frame_classification.prep_dimensionality_reduction_dataset`""" + from __future__ import annotations import json diff --git a/src/vak/datasets/parametric_umap/parametric_umap.py b/src/vak/datasets/parametric_umap/parametric_umap.py index 052975d9c..d95cb0150 100644 --- a/src/vak/datasets/parametric_umap/parametric_umap.py +++ b/src/vak/datasets/parametric_umap/parametric_umap.py @@ -1,4 +1,5 @@ """A dataset class used to train Parametric UMAP models.""" + from __future__ import annotations import pathlib diff --git a/src/vak/eval/eval_.py b/src/vak/eval/eval_.py index 7f57d8f99..fa1209f1d 100644 --- a/src/vak/eval/eval_.py +++ b/src/vak/eval/eval_.py @@ -1,4 +1,5 @@ """High-level function that evaluates trained models.""" + from __future__ import annotations import logging @@ -13,16 +14,13 @@ def eval( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, checkpoint_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, labelmap_path: str | pathlib.Path | None = None, batch_size: int | None = None, - transform_params: dict | None = None, - dataset_params: dict | None = None, split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, @@ -32,14 +30,12 @@ def eval( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str, pathlib.Path - Path to dataset, e.g., a csv file generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. checkpoint_path : str, pathlib.Path path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -53,14 +49,6 @@ def eval( batch_size : int, optional. Number of samples per batch fed into model. Optional, default is None. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. split : str split of dataset on which model should be evaluated. One of {'train', 'val', 'test'}. Default is 'test'. 
@@ -105,29 +93,28 @@ def eval( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: raise ValueError( f"No model family found for the model name specified: {model_name}" ) from e + if model_family == "FrameClassificationModel": eval_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, checkpoint_path=checkpoint_path, labelmap_path=labelmap_path, output_dir=output_dir, num_workers=num_workers, - transform_params=transform_params, - dataset_params=dataset_params, split=split, spect_scaler_path=spect_scaler_path, device=device, @@ -135,15 +122,12 @@ def eval( ) elif model_family == "ParametricUMAPModel": eval_parametric_umap_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, checkpoint_path=checkpoint_path, output_dir=output_dir, batch_size=batch_size, num_workers=num_workers, - transform_params=transform_params, - dataset_params=dataset_params, split=split, device=device, ) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 531c55d6c..9757287e8 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -1,4 +1,5 @@ """Function that evaluates trained models in the frame classification family.""" + from __future__ import annotations import json @@ -20,15 +21,12 @@ def eval_frame_classification_model( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, - transform_params: dict | None = None, - dataset_params: dict | None = None, split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, @@ -38,14 +36,12 @@ def eval_frame_classification_model( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str, pathlib.Path - Path to dataset, e.g., a csv file generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. checkpoint_path : str, pathlib.Path Path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -55,14 +51,6 @@ def eval_frame_classification_model( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. split : str Split of dataset on which model should be evaluated. 
        One of {'train', 'val', 'test'}. Default is 'test'.
@@ -107,7 +95,7 @@
                 f"value for ``{path_name}`` not recognized as a file: {path}"
             )

-    dataset_path = pathlib.Path(dataset_path)
+    dataset_path = pathlib.Path(dataset_config["path"])
     if not dataset_path.exists() or not dataset_path.is_dir():
         raise NotADirectoryError(
             f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
@@ -141,19 +129,30 @@
         logger.info(f"loading labelmap from path: {labelmap_path}")
         with labelmap_path.open("r") as f:
             labelmap = json.load(f)
-    if transform_params is None:
-        transform_params = {}
-    transform_params.update({"spect_standardizer": spect_standardizer})
+
+    model_name = model_config["name"]
+    # TODO: move this into datapipe once each datapipe uses a fixed set of transforms
+    # that will require adding ``spect_standardizer`` as a parameter to the datapipe,
+    # maybe rename to `frames_standardizer`?
+    try:
+        window_size = dataset_config["params"]["window_size"]
+    except KeyError as e:
+        raise KeyError(
+            f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table "
+            f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}"
+        ) from e
+    transform_params = {
+        "spect_standardizer": spect_standardizer,
+        "window_size": window_size,
+    }
+
     item_transform = transforms.defaults.get_default_transform(
         model_name, "eval", transform_params
     )
-    if dataset_params is None:
-        dataset_params = {}
     val_dataset = FramesDataset.from_dataset_path(
         dataset_path=dataset_path,
         split=split,
         item_transform=item_transform,
-        **dataset_params,
     )
     val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,
diff --git a/src/vak/eval/parametric_umap.py b/src/vak/eval/parametric_umap.py
index 09dd8891b..107d8d844 100644
--- a/src/vak/eval/parametric_umap.py
+++ b/src/vak/eval/parametric_umap.py
@@ -1,4 +1,5 @@
 """Function that evaluates trained models in the parametric UMAP family."""
+
 from __future__ import annotations

 import logging
@@ -18,15 +19,12 @@


 def eval_parametric_umap_model(
-    model_name: str,
     model_config: dict,
-    dataset_path: str | pathlib.Path,
+    dataset_config: dict,
     checkpoint_path: str | pathlib.Path,
     output_dir: str | pathlib.Path,
     batch_size: int,
     num_workers: int,
-    transform_params: dict | None = None,
-    dataset_params: dict | None = None,
     split: str = "test",
     device: str | None = None,
 ) -> None:
@@ -34,14 +32,12 @@

     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
-    dataset_path : str, pathlib.Path
-        Path to dataset, e.g., a csv file generated by running ``vak prep``.
+        Model configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
     checkpoint_path : str, pathlib.Path
         Path to directory with checkpoint files saved by Torch, to reload model
     output_dir : str, pathlib.Path
         Path to location where eval results will be saved.
     batch_size : int
         number of samples per batch presented to models during training.
     num_workers : int
         Number of processes to use for parallel loading of data.
         Argument to torch.DataLoader. Default is 2.
-    transform_params: dict, optional
-        Parameters for data transform.
-        Passed as keyword arguments.
-        Optional, default is None.
-    dataset_params: dict, optional
-        Parameters for dataset.
-        Passed as keyword arguments.
-        Optional, default is None.
     split : str
         Split of dataset on which model should be evaluated.
         One of {'train', 'val', 'test'}. Default is 'test'.
@@ -77,7 +65,7 @@
                 f"value for ``{path_name}`` not recognized as a file: {path}"
             )

-    dataset_path = pathlib.Path(dataset_path)
+    dataset_path = pathlib.Path(dataset_config["path"])
     if not dataset_path.exists() or not dataset_path.is_dir():
         raise NotADirectoryError(
             f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
@@ -95,18 +83,15 @@
     timenow = datetime.now().strftime("%y%m%d_%H%M%S")

     # ---------------- load data for evaluation ------------------------------------------------------------------------
-    if transform_params is None:
-        transform_params = {}
+    model_name = model_config["name"]
     item_transform = transforms.defaults.get_default_transform(
-        model_name, "eval", transform_params
+        model_name, "eval"
     )
-    if dataset_params is None:
-        dataset_params = {}
     val_dataset = ParametricUMAPDataset.from_dataset_path(
         dataset_path=dataset_path,
         split=split,
         transform=item_transform,
-        **dataset_params,
+        **dataset_config["params"],
     )
     val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,
diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py
index f363a946a..8ca2e11b7 100644
--- a/src/vak/learncurve/frame_classification.py
+++ b/src/vak/learncurve/frame_classification.py
@@ -1,4 +1,5 @@
 """Function that generates results for a learning curve for frame classification models."""
+
 from __future__ import annotations

 import logging
@@ -16,17 +17,12 @@


 def learning_curve_for_frame_classification_model(
-    model_name: str,
     model_config: dict,
-    dataset_path: str | pathlib.Path,
+    dataset_config: dict,
     batch_size: int,
     num_epochs: int,
     num_workers: int,
     results_path: str | pathlib.Path,
-    train_transform_params: dict | None = None,
-    train_dataset_params: dict | None = None,
-    val_transform_params: dict | None = None,
-    val_dataset_params: dict | None = None,
     post_tfm_kwargs: dict | None = None,
     normalize_spectrograms: bool = True,
     shuffle: bool = True,
@@ -47,12 +43,12 @@

     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
+        Model configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
-    dataset_path : str
-        path to where dataset was saved as a csv.
     batch_size : int
         number of samples per batch presented to models during training.
     num_epochs : int
         number of training epochs. One epoch = one iteration through the entire training set.
     num_workers : int
         Number of processes to use for parallel loading of data.
         Argument to torch.DataLoader.
     results_path : str, pathlib.Path
         Directory where results will be saved.
-    train_transform_params: dict, optional
-        Parameters for training data transform.
-        Passed as keyword arguments.
-        Optional, default is None.
-    train_dataset_params: dict, optional
-        Parameters for training dataset.
-        Passed as keyword arguments to
-        :class:`vak.datasets.frame_classification.WindowDataset`.
-        Optional, default is None.
-    val_transform_params: dict, optional
-        Parameters for validation data transform.
-        Passed as keyword arguments.
- Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments to - :class:`vak.datasets.frame_classification.FramesDataset`. - Optional, default is None. previous_run_path : str, Path Path to directory containing dataset .csv files that represent subsets of training set, created by @@ -129,7 +107,7 @@ def learning_curve_for_frame_classification_model( Default is None, in which case training only stops after the specified number of epochs. """ # ---------------- pre-conditions ---------------------------------------------------------------------------------- - dataset_path = expanded_user_path(dataset_path) + dataset_path = expanded_user_path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -180,6 +158,9 @@ def learning_curve_for_frame_classification_model( # ---- main loop that creates "learning curve" --------------------------------------------------------------------- logger.info("Starting training for learning curve.") + model_name = model_config[ + "name" + ] # used below when getting checkpoint path, etc for train_dur, replicate_num in to_do: logger.info( f"Training model with training set of size: {train_dur}s, replicate number {replicate_num}.", @@ -205,16 +186,11 @@ def learning_curve_for_frame_classification_model( ) train_frame_classification_model( - model_name, model_config, - dataset_path, + dataset_config, batch_size, num_epochs, num_workers, - train_transform_params, - train_dataset_params, - val_transform_params, - val_dataset_params, results_path=results_path_this_replicate, normalize_spectrograms=normalize_spectrograms, shuffle=shuffle, @@ -260,15 +236,12 @@ def learning_curve_for_frame_classification_model( spect_scaler_path = None eval_frame_classification_model( - model_name, model_config, - dataset_path, + dataset_config, ckpt_path, labelmap_path, results_path_this_replicate, num_workers, - val_transform_params, - val_dataset_params, "test", spect_scaler_path, post_tfm_kwargs, diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py index 781b601aa..0b6e443bf 100644 --- a/src/vak/learncurve/learncurve.py +++ b/src/vak/learncurve/learncurve.py @@ -1,4 +1,5 @@ """High-level function that generates results for a learning curve for all models.""" + from __future__ import annotations import logging @@ -12,16 +13,11 @@ def learning_curve( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, results_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, normalize_spectrograms: bool = True, @@ -43,14 +39,12 @@ def learning_curve( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - path to where dataset was saved as a csv. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. 
+ Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -107,12 +101,13 @@ def learning_curve( Default is None, in which case training only stops after the specified number of epochs. """ # ---------------- pre-conditions ---------------------------------------------------------------------------------- - dataset_path = expanded_user_path(dataset_path) + dataset_path = expanded_user_path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: @@ -121,16 +116,11 @@ def learning_curve( ) from e if model_family == "FrameClassificationModel": learning_curve_for_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, - train_transform_params=train_transform_params, - train_dataset_params=train_dataset_params, - val_transform_params=val_transform_params, - val_dataset_params=val_dataset_params, results_path=results_path, post_tfm_kwargs=post_tfm_kwargs, normalize_spectrograms=normalize_spectrograms, diff --git a/src/vak/metrics/util.py b/src/vak/metrics/util.py index 7bbdbc226..e9cf84dd1 100644 --- a/src/vak/metrics/util.py +++ b/src/vak/metrics/util.py @@ -8,6 +8,7 @@ https://setuptools.readthedocs.io/en/latest/setuptools.html#dynamic-discovery-of-services-and-plugins https://amir.rachum.com/blog/2017/07/28/python-entry-points/ """ + from .. import entry_points METRICS_ENTRY_POINT = "vak.models" diff --git a/src/vak/models/base.py b/src/vak/models/base.py index 2aa47022a..fd27b7ae3 100644 --- a/src/vak/models/base.py +++ b/src/vak/models/base.py @@ -1,6 +1,7 @@ """Base class for a model in ``vak``, that other families of models should subclass. """ + from __future__ import annotations import inspect diff --git a/src/vak/models/convencoder_umap.py b/src/vak/models/convencoder_umap.py index a7a894b23..7e06efe1c 100644 --- a/src/vak/models/convencoder_umap.py +++ b/src/vak/models/convencoder_umap.py @@ -5,6 +5,7 @@ with changes made by Tim Sainburg: https://github.com/lmcinnes/umap/issues/580#issuecomment-1368649550. """ + from __future__ import annotations import torch diff --git a/src/vak/models/decorator.py b/src/vak/models/decorator.py index a0aa717fe..5a0fff875 100644 --- a/src/vak/models/decorator.py +++ b/src/vak/models/decorator.py @@ -8,6 +8,7 @@ The subclass can then be instantiated and have all model methods. 
""" + from __future__ import annotations from typing import Type diff --git a/src/vak/models/definition.py b/src/vak/models/definition.py index 14b5435de..b3742d2a8 100644 --- a/src/vak/models/definition.py +++ b/src/vak/models/definition.py @@ -1,6 +1,7 @@ """Code that handles classes that represent the definition of a neural network model; the abstraction of how models are declared with code in vak.""" + from __future__ import annotations import dataclasses diff --git a/src/vak/models/ed_tcn.py b/src/vak/models/ed_tcn.py index 11a195531..38fd9a325 100644 --- a/src/vak/models/ed_tcn.py +++ b/src/vak/models/ed_tcn.py @@ -1,5 +1,6 @@ """ """ + from __future__ import annotations import torch diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index ce35dc401..305018e8c 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -2,6 +2,7 @@ where a model predicts a label for each frame in a time series, e.g., each time bin in a window from a spectrogram.""" + from __future__ import annotations import logging @@ -199,9 +200,9 @@ def training_step(self, batch: tuple, batch_idx: int): Scalar loss value computed by the loss function, ``self.loss``. """ - x, y = batch[0], batch[1] - out = self.network(x) - loss = self.loss(out, y) + frames, frame_labels = batch["frames"], batch["frame_labels"] + out = self.network(frames) + loss = self.loss(out, frame_labels) self.log("train_loss", loss, on_step=True) return loss diff --git a/src/vak/models/get.py b/src/vak/models/get.py index e4cc3ab06..b6f6a849c 100644 --- a/src/vak/models/get.py +++ b/src/vak/models/get.py @@ -1,5 +1,6 @@ """Function that gets an instance of a model, given its name and a configuration as a dict.""" + from __future__ import annotations import inspect diff --git a/src/vak/models/parametric_umap_model.py b/src/vak/models/parametric_umap_model.py index 67203b71c..4d2c2cb94 100644 --- a/src/vak/models/parametric_umap_model.py +++ b/src/vak/models/parametric_umap_model.py @@ -5,6 +5,7 @@ with changes made by Tim Sainburg: https://github.com/lmcinnes/umap/issues/580#issuecomment-1368649550. """ + from __future__ import annotations import pathlib diff --git a/src/vak/models/registry.py b/src/vak/models/registry.py index e9f01d23c..b187b2480 100644 --- a/src/vak/models/registry.py +++ b/src/vak/models/registry.py @@ -3,6 +3,7 @@ Makes it possible to register a model declared outside of ``vak`` with a decorator, so that the model can be used at runtime. 
""" + from __future__ import annotations import inspect diff --git a/src/vak/models/tweetynet.py b/src/vak/models/tweetynet.py index 62e7b58bf..b9631be59 100644 --- a/src/vak/models/tweetynet.py +++ b/src/vak/models/tweetynet.py @@ -6,6 +6,7 @@ Paper: https://elifesciences.org/articles/63853 Code: https://github.com/yardencsGitHub/tweetynet """ + from __future__ import annotations import torch diff --git a/src/vak/nets/tweetynet.py b/src/vak/nets/tweetynet.py index ed2ec5e7b..ab5f8defc 100644 --- a/src/vak/nets/tweetynet.py +++ b/src/vak/nets/tweetynet.py @@ -1,4 +1,5 @@ """TweetyNet model""" + from __future__ import annotations import torch diff --git a/src/vak/nn/loss/umap.py b/src/vak/nn/loss/umap.py index 8e59be403..077ee0ef2 100644 --- a/src/vak/nn/loss/umap.py +++ b/src/vak/nn/loss/umap.py @@ -1,4 +1,5 @@ """Parametric UMAP loss function.""" + from __future__ import annotations import warnings @@ -77,7 +78,7 @@ def umap_loss( distance_embedding = torch.cat( ( (embedding_to - embedding_from).norm(dim=1), - (embedding_neg_to - embedding_neg_from).norm(dim=1) + (embedding_neg_to - embedding_neg_from).norm(dim=1), # ``to`` method in next line to avoid error `Expected all tensors to be on the same device` ), dim=0, diff --git a/src/vak/nn/modules/activation.py b/src/vak/nn/modules/activation.py index 57a884496..5173ceee8 100644 --- a/src/vak/nn/modules/activation.py +++ b/src/vak/nn/modules/activation.py @@ -1,4 +1,5 @@ """Modules that act as activation functions.""" + import torch diff --git a/src/vak/nn/modules/conv.py b/src/vak/nn/modules/conv.py index 778e5abf3..c249a1b56 100644 --- a/src/vak/nn/modules/conv.py +++ b/src/vak/nn/modules/conv.py @@ -1,4 +1,5 @@ """Modules that perform neural network convolutions.""" + import torch from torch.nn import functional as F diff --git a/src/vak/plot/annot.py b/src/vak/plot/annot.py index fca7294d0..dcb8180f3 100644 --- a/src/vak/plot/annot.py +++ b/src/vak/plot/annot.py @@ -1,4 +1,5 @@ """functions for plotting annotations for vocalizations""" + import matplotlib.pyplot as plt import numpy as np from matplotlib.collections import LineCollection diff --git a/src/vak/plot/learncurve.py b/src/vak/plot/learncurve.py index c78cf0792..b304ec5c1 100644 --- a/src/vak/plot/learncurve.py +++ b/src/vak/plot/learncurve.py @@ -1,4 +1,5 @@ """functions to plot learning curve results""" + import os import pickle from configparser import ConfigParser diff --git a/src/vak/plot/spect.py b/src/vak/plot/spect.py index 286e357dd..99109649a 100644 --- a/src/vak/plot/spect.py +++ b/src/vak/plot/spect.py @@ -1,4 +1,5 @@ """functions for plotting spectrograms""" + import matplotlib.pyplot as plt from .annot import annotation diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index b029ea809..765fb3134 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -1,4 +1,5 @@ """Function that generates new inferences from trained models in the frame classification family.""" + from __future__ import annotations import json @@ -22,14 +23,11 @@ def predict_with_frame_classification_model( - model_name: str, model_config: dict, - dataset_path, + dataset_config: dict, checkpoint_path, labelmap_path, num_workers=2, - transform_params: dict | None = None, - dataset_params: dict | None = None, timebins_key="t", spect_scaler_path=None, device=None, @@ -39,75 +37,66 @@ def predict_with_frame_classification_model( majority_vote=False, save_net_outputs=False, ): - """Make predictions on a dataset 
with a trained model. + """Make predictions on a dataset with a trained + :class:`~vak.models.FrameClassificationModel`. Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. - checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model - labelmap_path : str - path to 'labelmap.json' file. - num_workers : int - Number of processes to use for parallel loading of data. - Argument to torch.DataLoader. Default is 2. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. - spect_key : str - key for accessing spectrogram in files. Default is 's'. - timebins_key : str - key for accessing vector of time bins in files. Default is 't'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. - spect_scaler_path : str - path to a saved SpectScaler object used to normalize spectrograms. - If spectrograms were normalized and this is not provided, will give - incorrect results. - annot_csv_filename : str - name of .csv file containing predicted annotations. - Default is None, in which case the name of the dataset .csv - is used, with '.annot.csv' appended to it. - output_dir : str, Path - path to location where .csv containing predicted annotation - should be saved. Defaults to current working directory. - min_segment_dur : float - minimum duration of segment, in seconds. If specified, then - any segment with a duration less than min_segment_dur is - removed from lbl_tb. Default is None, in which case no - segments are removed. - majority_vote : bool - if True, transform segments containing multiple labels - into segments with a single label by taking a "majority vote", - i.e. assign all time bins in the segment the most frequently - occurring label in the segment. This transform can only be - applied if the labelmap contains an 'unlabeled' label, - because unlabeled segments makes it possible to identify - the labeled segments. Default is False. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + checkpoint_path : str + path to directory with checkpoint files saved by Torch, to reload model + labelmap_path : str + path to 'labelmap.json' file. + num_workers : int + Number of processes to use for parallel loading of data. + Argument to torch.DataLoader. Default is 2. + spect_key : str + key for accessing spectrogram in files. Default is 's'. + timebins_key : str + key for accessing vector of time bins in files. Default is 't'. + device : str + Device on which to work with model + data. + Defaults to 'cuda' if torch.cuda.is_available is True. + spect_scaler_path : str + path to a saved SpectScaler object used to normalize spectrograms. + If spectrograms were normalized and this is not provided, will give + incorrect results. + annot_csv_filename : str + name of .csv file containing predicted annotations. 
+        Default is None, in which case the name of the dataset .csv
+        is used, with '.annot.csv' appended to it.
+    output_dir : str, Path
+        path to location where .csv containing predicted annotation
+        should be saved. Defaults to current working directory.
+    min_segment_dur : float
+        minimum duration of segment, in seconds. If specified, then
+        any segment with a duration less than min_segment_dur is
+        removed from lbl_tb. Default is None, in which case no
+        segments are removed.
+    majority_vote : bool
+        if True, transform segments containing multiple labels
+        into segments with a single label by taking a "majority vote",
+        i.e. assign all time bins in the segment the most frequently
+        occurring label in the segment. This transform can only be
+        applied if the labelmap contains an 'unlabeled' label,
+        because unlabeled segments make it possible to identify
+        the labeled segments. Default is False.
     save_net_outputs : bool
-        if True, save 'raw' outputs of neural networks
-        before they are converted to annotations. Default is False.
-        Typically the output will be "logits"
-        to which a softmax transform might be applied.
-        For each item in the dataset--each row in the `dataset_path` .csv--
-        the output will be saved in a separate file in `output_dir`,
-        with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
-        spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
-        and the network is `TweetyNet`, then the net output file
-        will be `gy6or6_032312_081416.tweetynet.output.npz`.
+        if True, save 'raw' outputs of neural networks
+        before they are converted to annotations. Default is False.
+        Typically the output will be "logits"
+        to which a softmax transform might be applied.
+        For each item in the dataset--each row in the `dataset_path` .csv--
+        the output will be saved in a separate file in `output_dir`,
+        with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
+        spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
+        and the network is `TweetyNet`, then the net output file
+        will be `gy6or6_032312_081416.tweetynet.output.npz`.
     """
     for path, path_name in zip(
         (checkpoint_path, labelmap_path, spect_scaler_path),
@@ -119,7 +108,7 @@
                 f"value for ``{path_name}`` not recognized as a file: {path}"
             )

-    dataset_path = pathlib.Path(dataset_path)
+    dataset_path = pathlib.Path(dataset_config["path"])
     if not dataset_path.exists() or not dataset_path.is_dir():
         raise NotADirectoryError(
             f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
@@ -146,9 +135,21 @@
         logger.info("Not loading SpectScaler, no path was specified")
         spect_standardizer = None

-    if transform_params is None:
-        transform_params = {}
-    transform_params.update({"spect_standardizer": spect_standardizer})
+    model_name = model_config["name"]
+    # TODO: move this into datapipe once each datapipe uses a fixed set of transforms
+    # that will require adding ``spect_standardizer`` as a parameter to the datapipe,
+    # maybe rename to `frames_standardizer`?
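+    # As an illustration (keys are those of DatasetConfig, values hypothetical),
+    # `dataset_config` here is expected to look like:
+    # {"name": ..., "path": ..., "splits_path": ..., "params": {"window_size": 2000}}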
+    try:
+        window_size = dataset_config["params"]["window_size"]
+    except KeyError as e:
+        raise KeyError(
+            f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table "
+            f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}"
+        ) from e
+    transform_params = {
+        "spect_standardizer": spect_standardizer,
+        "window_size": window_size,
+    }
     item_transform = transforms.defaults.get_default_transform(
         model_name, "predict", transform_params
     )
@@ -165,13 +166,12 @@
     logger.info(
         f"loading dataset to predict from csv path: {dataset_csv_path}"
     )
-    if dataset_params is None:
-        dataset_params = {}
+
+    # TODO: fix this when we build transforms into datasets; pass in `window_size` here
     pred_dataset = FramesDataset.from_dataset_path(
         dataset_path=dataset_path,
         split="predict",
         item_transform=item_transform,
-        **dataset_params,
     )

     pred_loader = torch.utils.data.DataLoader(
diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py
index 66955f2a9..4e54336f4 100644
--- a/src/vak/predict/parametric_umap.py
+++ b/src/vak/predict/parametric_umap.py
@@ -1,4 +1,5 @@
-"""Function that generates new inferences from trained models in the frame classification family."""
+"""Function that generates new inferences from trained models in the parametric UMAP family."""
+
 from __future__ import annotations

 import logging
@@ -17,9 +18,8 @@


 def predict_with_parametric_umap_model(
-    model_name: str,
     model_config: dict,
-    dataset_path,
+    dataset_config: dict,
     checkpoint_path,
     num_workers=2,
     transform_params: dict | None = None,
@@ -28,23 +28,22 @@
     device=None,
     output_dir=None,
 ):
-    """Make predictions on a dataset with a trained model.
+    """Make predictions on a dataset with a trained
+    :class:`vak.models.ParametricUMAPModel`.

     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
-    dataset_path : str
-        Path to dataset, e.g., a csv file generated by running ``vak prep``.
-    checkpoint_path : str
-        path to directory with checkpoint files saved by Torch, to reload model
-    num_workers : int
-        Number of processes to use for parallel loading of data.
-        Argument to torch.DataLoader. Default is 2.
+        Model configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+    checkpoint_path : str
+        path to directory with checkpoint files saved by Torch, to reload model
+    num_workers : int
+        Number of processes to use for parallel loading of data.
+        Argument to torch.DataLoader. Default is 2.
     transform_params: dict, optional
         Parameters for data transform.
         Passed as keyword arguments.
         Optional, default is None.
     dataset_params: dict, optional
         Parameters for dataset.
         Passed as keyword arguments.
         Optional, default is None.
-    timebins_key : str
-        key for accessing vector of time bins in files. Default is 't'.
-    device : str
-        Device on which to work with model + data.
-        Defaults to 'cuda' if torch.cuda.is_available is True.
-    annot_csv_filename : str
-        name of .csv file containing predicted annotations.
-        Default is None, in which case the name of the dataset .csv
-        is used, with '.annot.csv' appended to it.
- output_dir : str, Path - path to location where .csv containing predicted annotation - should be saved. Defaults to current working directory. + timebins_key : str + key for accessing vector of time bins in files. Default is 't'. + device : str + Device on which to work with model + data. + Defaults to 'cuda' if torch.cuda.is_available is True. + annot_csv_filename : str + name of .csv file containing predicted annotations. + Default is None, in which case the name of the dataset .csv + is used, with '.annot.csv' appended to it. + output_dir : str, Path + path to location where .csv containing predicted annotation + should be saved. Defaults to current working directory. """ for path, path_name in zip( (checkpoint_path,), @@ -76,7 +75,7 @@ def predict_with_parametric_umap_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -102,12 +101,14 @@ def predict_with_parametric_umap_model( device = get_default_device() # ---------------- load data for prediction ------------------------------------------------------------------------ - if transform_params is None: - transform_params = {} - if "padding" not in transform_params and model_name == "ConvEncoderUMAP": - padding = models.convencoder_umap.get_default_padding(metadata.shape) - transform_params["padding"] = padding - + model_name = model_config["name"] + # TODO: fix this when we build transforms into datasets + transform_params = { + "padding": dataset_config["params"].get( + "padding", + models.convencoder_umap.get_default_padding(metadata.shape), + ) + } item_transform = transforms.defaults.get_default_transform( model_name, "predict", transform_params ) @@ -117,13 +118,11 @@ def predict_with_parametric_umap_model( f"loading dataset to predict from csv path: {dataset_csv_path}" ) - if dataset_params is None: - dataset_params = {} pred_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="predict", transform=item_transform, - **dataset_params, + **dataset_config["params"], ) pred_loader = torch.utils.data.DataLoader( diff --git a/src/vak/predict/predict_.py b/src/vak/predict/predict_.py index 29208d0f2..60373b11d 100644 --- a/src/vak/predict/predict_.py +++ b/src/vak/predict/predict_.py @@ -1,4 +1,5 @@ """High-level function that generates new inferences from trained models.""" + from __future__ import annotations import logging @@ -14,14 +15,11 @@ def predict( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, num_workers: int = 2, - transform_params: dict | None = None, - dataset_params: dict | None = None, timebins_key: str = "t", spect_scaler_path: str | pathlib.Path | None = None, device: str | None = None, @@ -35,72 +33,62 @@ def predict( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict Model configuration in a ``dict``, as loaded from a .toml file, and used by the model method ``from_config``. - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. - checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model - labelmap_path : str - path to 'labelmap.json' file. 
-    window_size : int
-        size of windows taken from spectrograms, in number of time bins,
-        shown to neural networks
-    num_workers : int
-        Number of processes to use for parallel loading of data.
-        Argument to torch.DataLoader. Default is 2.
-    transform_params: dict, optional
-        Parameters for data transform.
-        Passed as keyword arguments.
-        Optional, default is None.
-    dataset_params: dict, optional
-        Parameters for dataset.
-        Passed as keyword arguments.
-        Optional, default is None.
-    timebins_key : str
-        key for accessing vector of time bins in files. Default is 't'.
-    device : str
-        Device on which to work with model + data.
-        Defaults to 'cuda' if torch.cuda.is_available is True.
-    spect_scaler_path : str
-        path to a saved SpectScaler object used to normalize spectrograms.
-        If spectrograms were normalized and this is not provided, will give
-        incorrect results.
-    annot_csv_filename : str
-        name of .csv file containing predicted annotations.
-        Default is None, in which case the name of the dataset .csv
-        is used, with '.annot.csv' appended to it.
-    output_dir : str, Path
-        path to location where .csv containing predicted annotation
-        should be saved. Defaults to current working directory.
-    min_segment_dur : float
-        minimum duration of segment, in seconds. If specified, then
-        any segment with a duration less than min_segment_dur is
-        removed from lbl_tb. Default is None, in which case no
-        segments are removed.
-    majority_vote : bool
-        if True, transform segments containing multiple labels
-        into segments with a single label by taking a "majority vote",
-        i.e. assign all time bins in the segment the most frequently
-        occurring label in the segment. This transform can only be
-        applied if the labelmap contains an 'unlabeled' label,
-        because unlabeled segments makes it possible to identify
-        the labeled segments. Default is False.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+    checkpoint_path : str
+        path to directory with checkpoint files saved by Torch, to reload model
+    labelmap_path : str
+        path to 'labelmap.json' file.
+    num_workers : int
+        Number of processes to use for parallel loading of data.
+        Argument to torch.DataLoader. Default is 2.
+    timebins_key : str
+        key for accessing vector of time bins in files. Default is 't'.
+    device : str
+        Device on which to work with model + data.
+        Defaults to 'cuda' if torch.cuda.is_available is True.
+    spect_scaler_path : str
+        path to a saved SpectScaler object used to normalize spectrograms.
+        If spectrograms were normalized and this is not provided, will give
+        incorrect results.
+    annot_csv_filename : str
+        name of .csv file containing predicted annotations.
+        Default is None, in which case the name of the dataset .csv
+        is used, with '.annot.csv' appended to it.
+    output_dir : str, Path
+        path to location where .csv containing predicted annotation
+        should be saved. Defaults to current working directory.
+    min_segment_dur : float
+        minimum duration of segment, in seconds. If specified, then
+        any segment with a duration less than min_segment_dur is
+        removed from lbl_tb. Default is None, in which case no
+        segments are removed.
+    majority_vote : bool
+        if True, transform segments containing multiple labels
+        into segments with a single label by taking a "majority vote",
+        i.e. assign all time bins in the segment the most frequently
+        occurring label in the segment. This transform can only be
+        applied if the labelmap contains an 'unlabeled' label,
+        because unlabeled segments make it possible to identify
+        the labeled segments. Default is False.
     save_net_outputs : bool
-        if True, save 'raw' outputs of neural networks
-        before they are converted to annotations. Default is False.
-        Typically the output will be "logits"
-        to which a softmax transform might be applied.
-        For each item in the dataset--each row in the `dataset_path` .csv--
-        the output will be saved in a separate file in `output_dir`,
-        with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
-        spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
-        and the network is `TweetyNet`, then the net output file
-        will be `gy6or6_032312_081416.tweetynet.output.npz`.
+        If True, save 'raw' outputs of neural networks
+        before they are converted to annotations. Default is False.
+        Typically the output will be "logits"
+        to which a softmax transform might be applied.
+        For each item in the dataset--each row in the `dataset_path` .csv--
+        the output will be saved in a separate file in `output_dir`,
+        with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
+        spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
+        and the network is `TweetyNet`, then the net output file
+        will be `gy6or6_032312_081416.tweetynet.output.npz`.
     """
     for path, path_name in zip(
         (checkpoint_path, labelmap_path, spect_scaler_path),
@@ -112,7 +100,7 @@
                 f"value for ``{path_name}`` not recognized as a file: {path}"
             )

-    dataset_path = pathlib.Path(dataset_path)
+    dataset_path = pathlib.Path(dataset_config["path"])
     if not dataset_path.exists() or not dataset_path.is_dir():
         raise NotADirectoryError(
             f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
@@ -131,6 +119,7 @@
     if device is None:
         device = get_default_device()

+    model_name = model_config["name"]
     try:
         model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name]
     except KeyError as e:
@@ -139,14 +128,11 @@
         ) from e
     if model_family == "FrameClassificationModel":
         predict_with_frame_classification_model(
-            model_name=model_name,
             model_config=model_config,
-            dataset_path=dataset_path,
+            dataset_config=dataset_config,
             checkpoint_path=checkpoint_path,
             labelmap_path=labelmap_path,
             num_workers=num_workers,
-            transform_params=transform_params,
-            dataset_params=dataset_params,
             timebins_key=timebins_key,
             spect_scaler_path=spect_scaler_path,
             device=device,
diff --git a/src/vak/prep/audio_dataset.py b/src/vak/prep/audio_dataset.py
index e43444684..04674d00b 100644
--- a/src/vak/prep/audio_dataset.py
+++ b/src/vak/prep/audio_dataset.py
@@ -176,9 +176,11 @@ def abspath(a_path):
                 [
                     abspath(audio_path),
                     abspath(annot_path),
-                    annot_format
-                    if annot_format
-                    else constants.NO_ANNOTATION_FORMAT,
+                    (
+                        annot_format
+                        if annot_format
+                        else constants.NO_ANNOTATION_FORMAT
+                    ),
                     samplerate,
                     sample_dur,
                     audio_dur,
diff --git a/src/vak/prep/constants.py b/src/vak/prep/constants.py
index 68399dd4c..5657c2f0b 100644
--- a/src/vak/prep/constants.py
+++ b/src/vak/prep/constants.py
@@ -2,6 +2,7 @@

 Defined in a separate module to minimize circular imports.
 """
+
 from . 
import frame_classification, parametric_umap VALID_PURPOSES = frozenset( diff --git a/src/vak/prep/dataset_df_helper.py b/src/vak/prep/dataset_df_helper.py index 81f73f9ba..2146f76e2 100644 --- a/src/vak/prep/dataset_df_helper.py +++ b/src/vak/prep/dataset_df_helper.py @@ -1,4 +1,5 @@ """Helper functions for working with datasets represented as a pandas.DataFrame""" + from __future__ import annotations import pathlib diff --git a/src/vak/prep/frame_classification/assign_samples_to_splits.py b/src/vak/prep/frame_classification/assign_samples_to_splits.py index fc272a93e..ef74eae3f 100644 --- a/src/vak/prep/frame_classification/assign_samples_to_splits.py +++ b/src/vak/prep/frame_classification/assign_samples_to_splits.py @@ -5,6 +5,7 @@ Helper function called by :func:`vak.prep.frame_classification.prep_frame_classification_dataset`. """ + from __future__ import annotations import logging diff --git a/src/vak/prep/frame_classification/frame_classification.py b/src/vak/prep/frame_classification/frame_classification.py index 5bd9469b1..8ce6d29fe 100644 --- a/src/vak/prep/frame_classification/frame_classification.py +++ b/src/vak/prep/frame_classification/frame_classification.py @@ -1,5 +1,6 @@ """Function that prepares datasets for neural network models that perform the frame classification task.""" + from __future__ import annotations import json diff --git a/src/vak/prep/frame_classification/learncurve.py b/src/vak/prep/frame_classification/learncurve.py index 97ba55525..bae335c5d 100644 --- a/src/vak/prep/frame_classification/learncurve.py +++ b/src/vak/prep/frame_classification/learncurve.py @@ -1,5 +1,6 @@ """Functionality to prepare splits of frame classification datasets to generate a learning curve.""" + from __future__ import annotations import logging diff --git a/src/vak/prep/frame_classification/make_splits.py b/src/vak/prep/frame_classification/make_splits.py index e4fd01564..2af4b586d 100644 --- a/src/vak/prep/frame_classification/make_splits.py +++ b/src/vak/prep/frame_classification/make_splits.py @@ -1,4 +1,5 @@ """Helper functions for frame classification dataset prep.""" + from __future__ import annotations import collections @@ -437,9 +438,11 @@ def _save_dataset_arrays_and_return_index_arrays( ] = frames_paths frame_labels_npy_paths = [ - sample.frame_labels_npy_path - if isinstance(sample.frame_labels_npy_path, str) - else None + ( + sample.frame_labels_npy_path + if isinstance(sample.frame_labels_npy_path, str) + else None + ) for sample in samples ] split_df[ diff --git a/src/vak/prep/frame_classification/validators.py b/src/vak/prep/frame_classification/validators.py index 91d56be7b..35e023771 100644 --- a/src/vak/prep/frame_classification/validators.py +++ b/src/vak/prep/frame_classification/validators.py @@ -1,4 +1,5 @@ """Validators for frame classification datasets""" + from __future__ import annotations import pandas as pd diff --git a/src/vak/prep/parametric_umap/dataset_arrays.py b/src/vak/prep/parametric_umap/dataset_arrays.py index 67e224ae7..82080ae5d 100644 --- a/src/vak/prep/parametric_umap/dataset_arrays.py +++ b/src/vak/prep/parametric_umap/dataset_arrays.py @@ -1,6 +1,7 @@ """Helper functions for `vak.prep.dimensionality_reduction` module that handle array files. 
""" + from __future__ import annotations import logging diff --git a/src/vak/prep/sequence_dataset.py b/src/vak/prep/sequence_dataset.py index 11ec2df86..067b7807a 100644 --- a/src/vak/prep/sequence_dataset.py +++ b/src/vak/prep/sequence_dataset.py @@ -1,4 +1,5 @@ """Helper functions for datasets annotated as sequences.""" + from __future__ import annotations import numpy as np diff --git a/src/vak/prep/spectrogram_dataset/__init__.py b/src/vak/prep/spectrogram_dataset/__init__.py index 54491e66f..15f1cc474 100644 --- a/src/vak/prep/spectrogram_dataset/__init__.py +++ b/src/vak/prep/spectrogram_dataset/__init__.py @@ -1,5 +1,6 @@ """Functions for preparing a dataset for neural network models from a dataset of spectrograms.""" + from .prep import prep_spectrogram_dataset __all__ = [ diff --git a/src/vak/prep/spectrogram_dataset/spect.py b/src/vak/prep/spectrogram_dataset/spect.py index d4d84ada0..f07562180 100644 --- a/src/vak/prep/spectrogram_dataset/spect.py +++ b/src/vak/prep/spectrogram_dataset/spect.py @@ -5,6 +5,7 @@ spectrogram adapted from code by Kyle Kastner and Tim Sainburg https://github.com/timsainb/python_spectrograms_and_inversion """ + import numpy as np from matplotlib.mlab import specgram from scipy.signal import butter, lfilter @@ -89,9 +90,9 @@ def spectrogram( spect[spect < thresh] = thresh else: if thresh: - spect[ - spect < thresh - ] = thresh # set anything less than the threshold as the threshold + spect[spect < thresh] = ( + thresh # set anything less than the threshold as the threshold + ) if freq_cutoffs: f_inds = np.nonzero( diff --git a/src/vak/prep/spectrogram_dataset/spect_helper.py b/src/vak/prep/spectrogram_dataset/spect_helper.py index 924e855b5..f0115e922 100644 --- a/src/vak/prep/spectrogram_dataset/spect_helper.py +++ b/src/vak/prep/spectrogram_dataset/spect_helper.py @@ -4,6 +4,7 @@ The columns of the dataframe are specified by :const:`vak.prep.spectrogram_dataset.spect_helper.DF_COLUMNS`. 
""" + from __future__ import annotations import logging @@ -239,9 +240,11 @@ def abspath(a_path): abspath(audio_path), abspath(spect_path), abspath(annot_path), - annot_format - if annot_format - else constants.NO_ANNOTATION_FORMAT, + ( + annot_format + if annot_format + else constants.NO_ANNOTATION_FORMAT + ), spect_dur, timebin_dur, ] diff --git a/src/vak/prep/split/split.py b/src/vak/prep/split/split.py index 23d37dd49..35ace9d26 100644 --- a/src/vak/prep/split/split.py +++ b/src/vak/prep/split/split.py @@ -1,5 +1,6 @@ """Functions for creating splits of datasets used with neural network models, such as the standard train-val-test splits used with supervised learning methods.""" + from __future__ import annotations import logging diff --git a/src/vak/prep/unit_dataset/unit_dataset.py b/src/vak/prep/unit_dataset/unit_dataset.py index 76a0e29b0..7d861b65a 100644 --- a/src/vak/prep/unit_dataset/unit_dataset.py +++ b/src/vak/prep/unit_dataset/unit_dataset.py @@ -1,5 +1,6 @@ """Functions for making a dataset of units from sequences, as used to train dimensionality reduction models.""" + from __future__ import annotations import logging diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index 256daaa84..25e07a062 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -1,4 +1,5 @@ """Function that trains models in the frame classification family.""" + from __future__ import annotations import datetime @@ -26,16 +27,11 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: def train_frame_classification_model( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, checkpoint_path: str | pathlib.Path | None = None, spect_scaler_path: str | pathlib.Path | None = None, results_path: str | pathlib.Path | None = None, @@ -58,14 +54,12 @@ def train_frame_classification_model( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, a directory generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -74,31 +68,7 @@ def train_frame_classification_model( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. - train_transform_params: dict, optional - Parameters for training data transform. - Passed as keyword arguments. - Optional, default is None. - train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments to - :class:`vak.datasets.frame_classification.WindowDataset`. - Optional, default is None. - val_transform_params: dict, optional - Parameters for validation data transform. - Passed as keyword arguments. Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. 
- Passed as keyword arguments to - :class:`vak.datasets.frame_classification.FramesDataset`. - Optional, default is None. - dataset_csv_path - Path to csv file representing splits of dataset, - e.g., such a file generated by running ``vak prep``. - This parameter is used by :func:`vak.core.learncurve` to specify - different splits to use, when generating results for a learning curve. - If this argument is specified, the csv file must be inside the directory - ``dataset_path``. checkpoint_path : str, pathlib.Path path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. @@ -157,14 +127,14 @@ f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) logger.info( - f"Loading dataset from path: {dataset_path}", + f"Loading dataset from `dataset_path`: {dataset_path}", ) metadata = datasets.frame_classification.Metadata.from_dataset_path( dataset_path ) @@ -239,22 +209,31 @@ ) spect_standardizer = None - if train_transform_params is None: - train_transform_params = {} - train_transform_params.update({"spect_standardizer": spect_standardizer}) - transform, target_transform = transforms.defaults.get_default_transform( - model_name, "train", transform_kwargs=train_transform_params + model_name = model_config["name"] + # TODO: move this into datapipe once each datapipe uses a fixed set of transforms + # that will require adding `spect_standardizer` as a parameter to the datapipe, + # maybe rename to `frames_standardizer`?
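+    # `dataset_config` is the `[vak.train.dataset]` table loaded into a dict,
+    # e.g. {"path": "...", "params": {"window_size": 88}}
+    # (compare the example configs in tests/data_for_tests/configs),
+    # so the 'params' sub-table must supply 'window_size' for this model family.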
+ try: + window_size = dataset_config["params"]["window_size"] + except KeyError as e: + raise KeyError( + f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table " + f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}" + ) from e + transform_kwargs = { + "spect_standardizer": spect_standardizer, + "window_size": window_size, + } + train_transform = transforms.defaults.get_default_transform( + model_name, "train", transform_kwargs=transform_kwargs ) - if train_dataset_params is None: - train_dataset_params = {} train_dataset = WindowDataset.from_dataset_path( dataset_path=dataset_path, split="train", subset=subset, - transform=transform, - target_transform=target_transform, - **train_dataset_params, + item_transform=train_transform, + **dataset_config["params"], ) logger.info( f"Duration of WindowDataset used for training, in seconds: {train_dataset.duration}", @@ -277,19 +256,15 @@ f"Total duration of validation split from dataset (in s): {val_dur}", ) - if val_transform_params is None: - val_transform_params = {} - val_transform_params.update({"spect_standardizer": spect_standardizer}) - item_transform = transforms.defaults.get_default_transform( - model_name, "eval", val_transform_params + # NOTE: we use the same `transform_kwargs` here; will need to change to a `dataset_param` + # when we factor transform *into* fixed DataPipes as above + val_transform = transforms.defaults.get_default_transform( + model_name, "eval", transform_kwargs ) - if val_dataset_params is None: - val_dataset_params = {} val_dataset = FramesDataset.from_dataset_path( dataset_path=dataset_path, split="val", - item_transform=item_transform, - **val_dataset_params, + item_transform=val_transform, ) logger.info( f"Duration of FramesDataset used for evaluation, in seconds: {val_dataset.duration}", diff --git a/src/vak/train/parametric_umap.py b/src/vak/train/parametric_umap.py index 675dac90d..f9180e5c0 100644 --- a/src/vak/train/parametric_umap.py +++ b/src/vak/train/parametric_umap.py @@ -1,4 +1,5 @@ """Function that trains models in the Parametric UMAP family.""" + from __future__ import annotations import datetime @@ -77,16 +78,11 @@ def get_trainer( def train_parametric_umap_model( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, checkpoint_path: str | pathlib.Path | None = None, root_results_dir: str | pathlib.Path | None = None, results_path: str | pathlib.Path | None = None, @@ -107,14 +103,12 @@ Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, a directory generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. batch_size : int number of samples per batch presented to models during training.
num_epochs : int @@ -123,16 +117,6 @@ def train_parametric_umap_model( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. - train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments to - :class:`vak.datasets.parametric_umap.ParametricUMAP`. - Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments to - :class:`vak.datasets.parametric_umap.ParametricUMAP`. - Optional, default is None. checkpoint_path : str, pathlib.Path, optional path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. @@ -175,7 +159,7 @@ def train_parametric_umap_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -218,20 +202,18 @@ def train_parametric_umap_model( f"Total duration of training split from dataset (in s): {train_dur}", ) - if train_transform_params is None: - train_transform_params = {} - transform = transforms.defaults.get_default_transform( - model_name, "train", train_transform_params + model_name = model_config["name"] + train_transform = transforms.defaults.get_default_transform( + model_name, "train" ) - if train_dataset_params is None: - train_dataset_params = {} + dataset_params = dataset_config["params"] train_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="train", subset=subset, - transform=transform, - **train_dataset_params, + transform=train_transform, + **dataset_params, ) logger.info( f"Duration of ParametricUMAPDataset used for training, in seconds: {train_dataset.duration}", @@ -245,18 +227,14 @@ def train_parametric_umap_model( # ---------------- load validation set (if there is one) ----------------------------------------------------------- if val_step: - if val_transform_params is None: - val_transform_params = {} transform = transforms.defaults.get_default_transform( - model_name, "eval", val_transform_params + model_name, "eval" ) - if val_dataset_params is None: - val_dataset_params = {} val_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="val", transform=transform, - **val_dataset_params, + **dataset_params, ) logger.info( f"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}", diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py index 79ee2897f..96926967d 100644 --- a/src/vak/train/train_.py +++ b/src/vak/train/train_.py @@ -1,4 +1,5 @@ """High-level function that trains models.""" + from __future__ import annotations import logging @@ -13,16 +14,11 @@ def train( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, checkpoint_path: str | pathlib.Path | None = None, spect_scaler_path: str | pathlib.Path | None = None, results_path: str | pathlib.Path | None = None, @@ -42,14 +38,12 @@ def train( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. 
model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, a directory generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. window_size : int size of windows taken from spectrograms, in number of time bins, shown to neural networks @@ -61,22 +55,6 @@ def train( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. - train_transform_params: dict, optional - Parameters for training data transform. - Passed as keyword arguments. - Optional, default is None. - train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments. - Optional, default is None. - val_transform_params: dict, optional - Parameters for validation data transform. - Passed as keyword arguments. - Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments. - Optional, default is None. checkpoint_path : str, pathlib.Path Path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. @@ -152,12 +130,13 @@ def train( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: @@ -166,16 +145,11 @@ def train( ) from e if model_family == "FrameClassificationModel": train_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, - train_transform_params=train_transform_params, - train_dataset_params=train_dataset_params, - val_transform_params=val_transform_params, - val_dataset_params=val_dataset_params, checkpoint_path=checkpoint_path, spect_scaler_path=spect_scaler_path, results_path=results_path, @@ -189,16 +163,11 @@ def train( ) elif model_family == "ParametricUMAPModel": train_parametric_umap_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, - train_transform_params=train_transform_params, - train_dataset_params=train_dataset_params, - val_transform_params=val_transform_params, - val_dataset_params=val_dataset_params, checkpoint_path=checkpoint_path, results_path=results_path, shuffle=shuffle, diff --git a/src/vak/transforms/defaults/frame_classification.py b/src/vak/transforms/defaults/frame_classification.py index c4abd0f34..9d20cfe50 100644 --- a/src/vak/transforms/defaults/frame_classification.py +++ b/src/vak/transforms/defaults/frame_classification.py @@ -8,6 +8,7 @@ needed for specific neural network models, e.g., whether the returned output includes a mask to crop off padding that was added. 
""" + from __future__ import annotations from typing import Callable @@ -26,32 +27,32 @@ def __init__( ): if spect_standardizer is not None: if isinstance(spect_standardizer, vak_transforms.StandardizeSpect): - source_transform = [spect_standardizer] + frames_transform = [spect_standardizer] else: raise TypeError( f"invalid type for spect_standardizer: {type(spect_standardizer)}. " "Should be an instance of vak.transforms.StandardizeSpect" ) else: - source_transform = [] + frames_transform = [] - source_transform.extend( + frames_transform.extend( [ vak_transforms.ToFloatTensor(), vak_transforms.AddChannel(), ] ) - self.source_transform = torchvision.transforms.Compose( - source_transform + self.frames_transform = torchvision.transforms.Compose( + frames_transform ) - self.annot_transform = vak_transforms.ToLongTensor() + self.frame_labels_transform = vak_transforms.ToLongTensor() - def __call__(self, source, annot, spect_path=None): - source = self.source_transform(source) - annot = self.annot_transform(annot) + def __call__(self, frames, frame_labels, spect_path=None): + frames = self.frames_transform(frames) + frame_labels = self.frame_labels_transform(frame_labels) item = { - "frames": source, - "frame_labels": annot, + "frames": frames, + "frame_labels": frame_labels, } if spect_path is not None: @@ -199,15 +200,18 @@ def __call__(self, frames, frames_path=None): def get_default_frame_classification_transform( - mode: str, transform_kwargs: dict + mode: str, transform_kwargs: dict | None = None ) -> tuple[Callable, Callable] | Callable: """Get default transform for frame classification model. Parameters ---------- mode : str - transform_kwargs : dict - A dict with the following key-value pairs: + transform_kwargs : dict, optional + Keyword arguments for transform class. + Default is None. + If supplied, should be a :class:`dict`, + that can include the following key-value pairs: spect_standardizer : vak.transforms.StandardizeSpect instance that has already been fit to dataset, using fit_df method. Default is None, in which case no standardization transform is applied. @@ -226,8 +230,10 @@ def get_default_frame_classification_transform( Returns ------- - + transform: TrainItemTransform, EvalItemTransform, or PredictItemTransform """ + if transform_kwargs is None: + transform_kwargs = {} spect_standardizer = transform_kwargs.get("spect_standardizer", None) # regardless of mode, transform always starts with StandardizeSpect, if used if spect_standardizer is not None: @@ -238,21 +244,7 @@ def get_default_frame_classification_transform( ) if mode == "train": - if spect_standardizer is not None: - transform = [spect_standardizer] - else: - transform = [] - - transform.extend( - [ - vak_transforms.ToFloatTensor(), - vak_transforms.AddChannel(), - ] - ) - transform = torchvision.transforms.Compose(transform) - - target_transform = vak_transforms.ToLongTensor() - return transform, target_transform + return TrainItemTransform(spect_standardizer) elif mode == "predict": item_transform = PredictItemTransform( diff --git a/src/vak/transforms/defaults/get.py b/src/vak/transforms/defaults/get.py index 0851d515c..3d567bde7 100644 --- a/src/vak/transforms/defaults/get.py +++ b/src/vak/transforms/defaults/get.py @@ -1,16 +1,19 @@ """Helper function that gets default transforms for a model.""" + from __future__ import annotations +from typing import Callable, Literal + from ... import models from . 
import frame_classification, parametric_umap def get_default_transform( model_name: str, - mode: str, - transform_kwargs: dict, -): - """Get default transforms for a model, + mode: Literal["eval", "predict", "train"], + transform_kwargs: dict | None = None, +) -> Callable: + """Get default transform for a model, according to its family and what mode the model is being used in. @@ -19,14 +22,13 @@ model_name : str Name of model. mode : str - one of {'train', 'eval', 'predict'}. Determines set of transforms. + One of {'eval', 'predict', 'train'}. Returns ------- - transform, target_transform : callable - one or more vak transforms to be applied to inputs x and, during training, the target y. - If more than one transform, they are combined into an instance of torchvision.transforms.Compose. - Note that when mode is 'predict', the target transform is None. + item_transform : callable + Transform to be applied to the input :math:`x` of a model and, + during training, the target :math:`y`. """ try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] diff --git a/src/vak/transforms/defaults/parametric_umap.py b/src/vak/transforms/defaults/parametric_umap.py index 83c568b06..a62a7c29b 100644 --- a/src/vak/transforms/defaults/parametric_umap.py +++ b/src/vak/transforms/defaults/parametric_umap.py @@ -1,4 +1,5 @@ """Default transforms for Parametric UMAP models.""" + from __future__ import annotations import torchvision.transforms @@ -7,18 +8,22 @@ def get_default_parametric_umap_transform( - transform_kwargs, + transform_kwargs: dict | None = None, ) -> torchvision.transforms.Compose: """Get default transform for a Parametric UMAP model. Parameters ---------- - transform_kwargs : dict + transform_kwargs : dict, optional + Keyword arguments for transform class. + Default is None. Returns ------- transform : Callable """ + if transform_kwargs is None: + transform_kwargs = {} transforms = [ vak_transforms.ToFloatTensor(), vak_transforms.AddChannel(), diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index 0ea758840..7fd73ff30 100644 --- a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -17,6 +17,7 @@ and apply the most "popular" label within each segment to all timebins in that segment - postprocess: combines remove_short_segments and take_majority_vote in one transform """ + from __future__ import annotations import numpy as np diff --git a/src/vak/transforms/frame_labels/transforms.py b/src/vak/transforms/frame_labels/transforms.py index 2734b7da0..bcb81bc48 100644 --- a/src/vak/transforms/frame_labels/transforms.py +++ b/src/vak/transforms/frame_labels/transforms.py @@ -20,6 +20,7 @@ - PostProcess: combines two post-processing transforms applied to frame labels, ``remove_short_segments`` and ``take_majority_vote``, in one class.
""" + from __future__ import annotations import numpy as np diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml index a2be49143..b38979f48 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "parametric umap" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412" @@ -8,20 +8,19 @@ annot_format = "notmat" labelset = "iabcdefghjk" test_dur = 0.2 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 32 transform_type = "log_spect_plus_one" -[EVAL] +[vak.eval] checkpoint_path = "tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP/results_230727_210112/ConvEncoderUMAP/checkpoints/checkpoint.pt" -model = "ConvEncoderUMAP" batch_size = 64 num_workers = 16 device = "cuda" output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/ConvEncoderUMAP" -[ConvEncoderUMAP.network] +[vak.eval.model.ConvEncoderUMAP.network] conv1_filters = 8 conv2_filters = 16 conv_kernel_size = 3 @@ -30,5 +29,5 @@ conv_padding = 1 n_features_linear = 32 n_components = 2 -[ConvEncoderUMAP.optimizer] +[vak.eval.model.ConvEncoderUMAP.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml index 456c99468..8be5a4d3a 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "parametric umap" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -10,13 +10,12 @@ train_dur = 0.5 val_dur = 0.2 test_dur = 0.25 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 32 transform_type = "log_spect_plus_one" -[TRAIN] -model = "ConvEncoderUMAP" +[vak.train] batch_size = 64 num_epochs = 1 val_step = 1 @@ -25,7 +24,7 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP" -[ConvEncoderUMAP.network] +[vak.train.model.ConvEncoderUMAP.network] conv1_filters = 8 conv2_filters = 16 conv_kernel_size = 3 @@ -34,5 +33,5 @@ conv_padding = 1 n_features_linear = 32 n_components = 2 -[ConvEncoderUMAP.optimizer] +[vak.train.model.ConvEncoderUMAP.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml index 51f5157e4..12bfcba84 100644 --- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" labelset = "iabcdefghjk" @@ -7,31 +7,30 @@ output_dir = "./tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat audio_format = "cbin" annot_format = "notmat" -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[EVAL] +[vak.eval] checkpoint_path = 
"~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json" -model = "TweetyNet" batch_size = 11 num_workers = 16 device = "cuda" spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet" -[EVAL.post_tfm_kwargs] +[vak.eval.post_tfm_kwargs] majority_vote = true min_segment_dur = 0.02 -[EVAL.transform_params] -window_size = 88 +[vak.eval.dataset] +params = { window_size = 88 } -[TweetyNet.network] +[vak.eval.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -42,5 +41,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.eval.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml index 0922283e8..59868a28a 100644 --- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -12,15 +12,14 @@ test_dur = 30 train_set_durs = [ 4, 6,] num_replicates = 2 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[LEARNCURVE] -model = "TweetyNet" +[vak.learncurve] normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -31,17 +30,14 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet" -[LEARNCURVE.post_tfm_kwargs] +[vak.learncurve.post_tfm_kwargs] majority_vote = true min_segment_dur = 0.02 -[LEARNCURVE.train_dataset_params] -window_size = 88 +[vak.learncurve.dataset] +params = { window_size = 88 } -[LEARNCURVE.val_transform_params] -window_size = 88 - -[TweetyNet.network] +[vak.learncurve.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -52,5 +48,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.learncurve.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml index da6a9175c..3d794f314 100644 --- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml @@ -1,32 +1,31 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412" output_dir = "./tests/data_for_tests/generated/prep/predict/audio_cbin_annot_notmat/TweetyNet" spect_format = "npz" -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[PREDICT] +[vak.predict] spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" 
checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" -model = "TweetyNet" batch_size = 11 num_workers = 16 device = "cuda" output_dir = "./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet" annot_csv_filename = "bl26lb16.041912.annot.csv" -[PREDICT.transform_params] -window_size = 88 +[vak.predict.dataset] +params = { window_size = 88 } -[TweetyNet.network] +[vak.predict.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -37,5 +36,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.predict.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml index 2f72adfb1..9b751e7f0 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -10,15 +10,14 @@ train_dur = 50 val_dur = 15 test_dur = 30 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] -model = "TweetyNet" +[vak.train] normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -29,13 +28,10 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet" -[TRAIN.train_dataset_params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } -[TRAIN.val_transform_params] -window_size = 88 - -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -46,5 +42,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml index 932208616..c7ca91a96 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -10,15 +10,14 @@ train_dur = 50 val_dur = 15 test_dur = 30 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] -model = "TweetyNet" +[vak.train] normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -31,13 +30,10 @@ root_results_dir = "./tests/data_for_tests/generated/results/train_continue/audi checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" spect_scaler_path = 
"~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" -[TRAIN.train_dataset_params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } -[TRAIN.val_transform_params] -window_size = 88 - -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -48,5 +44,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml index aa384b6ed..c66e9c34d 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/spect_mat_annot_yarden/llb3/spect" @@ -10,15 +10,14 @@ labelset = "range: 1-3,6-14,17-19" train_dur = 213 val_dur = 213 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] -model = "TweetyNet" +[vak.train] normalize_spectrograms = false batch_size = 11 num_epochs = 2 @@ -30,13 +29,10 @@ device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train_continue/spect_mat_annot_yarden/TweetyNet" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" -[TRAIN.train_dataset_params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } -[TRAIN.val_transform_params] -window_size = 88 - -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -47,5 +43,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml index 770012f4f..a9aaaf112 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/spect_mat_annot_yarden/llb3/spect" @@ -10,15 +10,14 @@ labelset = "range: 1-3,6-14,17-19" train_dur = 213 val_dur = 213 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] -model = "TweetyNet" +[vak.train] normalize_spectrograms = false batch_size = 11 num_epochs = 2 @@ -29,13 +28,10 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/spect_mat_annot_yarden/TweetyNet" -[TRAIN.train_dataset_params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } -[TRAIN.val_transform_params] -window_size = 88 - -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -46,5 +42,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] 
+[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/invalid_option_config.toml b/tests/data_for_tests/configs/invalid_key_config.toml similarity index 61% rename from tests/data_for_tests/configs/invalid_option_config.toml rename to tests/data_for_tests/configs/invalid_key_config.toml index 5504fbf38..0012c6d6c 100644 --- a/tests/data_for_tests/configs/invalid_option_config.toml +++ b/tests/data_for_tests/configs/invalid_key_config.toml @@ -1,11 +1,12 @@ -# used to test that invalid option 'ouput_dir' (instead of 'output_dir') +# used to test that invalid key 'ouput_dir' (instead of 'output_dir') # raises a ValueError when passed to -# vak.config.validators.are_options_valid +# vak.config.validators.are_keys_valid -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = '/home/user/data/subdir/' -ouput_dir = '/why/do/i/keep/typing/ouput' # <-- invalid option 'ouput' instead of 'output' +# next line, invalid key 'ouput' instead of 'output' +ouput_dir = '/why/do/i/keep/typing/ouput' audio_format = 'cbin' annot_format = 'notmat' labelset = 'iabcdefghjk' @@ -20,8 +21,7 @@ freq_cutoffs = [500, 10000] thresh = 6.25 transform_type = 'log_spect' -[TRAIN] -model = 'TweetyNet' +[vak.train] root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true num_epochs = 2 @@ -30,8 +30,8 @@ val_error_step = 1 checkpoint_step = 1 save_only_single_checkpoint_file = true -[TRAIN.dataset_params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] learning_rate = 0.001 diff --git a/tests/data_for_tests/configs/invalid_section_config.toml b/tests/data_for_tests/configs/invalid_table_config.toml similarity index 72% rename from tests/data_for_tests/configs/invalid_section_config.toml rename to tests/data_for_tests/configs/invalid_table_config.toml index f77cde3a3..24998129d 100644 --- a/tests/data_for_tests/configs/invalid_section_config.toml +++ b/tests/data_for_tests/configs/invalid_table_config.toml @@ -1,7 +1,7 @@ -# used to test that invalid section 'TRIAN' (instead of 'TRAIN') +# used to test that invalid table 'trian' (instead of 'vak.train') # raises a ValueError when passed to # vak.config.validators.are_sections_valid -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = '/home/user/data/subdir/' @@ -13,14 +13,14 @@ train_dur = 10 val_dur = 5 test_dur = 10 -[SPECTROGRAM] +[vak.prep.spect_params] fft_size=512 step_size=64 freq_cutoffs = [500, 10000] thresh = 6.25 transform_type = 'log_spect' -[TRIAN] # <-- invalid section 'TRIAN' (instead of 'TRAIN') +[vak.trian] # <-- invalid table 'trian' (instead of 'vak.train') model = 'TweetyNet' root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true @@ -30,5 +30,5 @@ val_error_step = 1 checkpoint_step = 1 save_only_single_checkpoint_file = true -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] learning_rate = 0.001 diff --git a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml index ce13bb316..a4fcd542d 100644 --- a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml +++ b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/cbins/gy6or6/032312" @@ -10,7 +10,7 @@ train_dur
= 50 val_dur = 15 test_dur = 30 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size=512 step_size=64 freq_cutoffs = [500, 10000] @@ -18,9 +18,8 @@ thresh = 6.25 transform_type = "log_spect" # this .toml file should cause 'vak.config.parse.from_toml' to raise a ValueError -# because it defines both a TRAIN and a LEARNCURVE section -[TRAIN] -model = "TweetyNet" +# because it defines both a vak.train and a vak.learncurve section +[vak.train] normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -31,8 +30,7 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat" -[LEARNCURVE] -model = 'TweetyNet' +[vak.learncurve] normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -45,5 +43,5 @@ num_replicates = 2 device = "cuda" root_results_dir = './tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat' -[TweetyNet.optimizer] +[vak.learncurve.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/fixtures/annot.py b/tests/fixtures/annot.py index 40ba6dec0..d7c63a4af 100644 --- a/tests/fixtures/annot.py +++ b/tests/fixtures/annot.py @@ -1,7 +1,7 @@ """fixtures relating to annotation files""" import crowsetta import pytest -import toml +import tomlkit from .config import GENERATED_TEST_CONFIGS_ROOT @@ -75,8 +75,8 @@ def annot_list_notmat(): )[0] # get first config.toml from glob list # doesn't really matter which config, they all have labelset with a_train_notmat_config.open("r") as fp: - a_train_notmat_toml = toml.load(fp) -LABELSET_NOTMAT = a_train_notmat_toml["PREP"]["labelset"] + a_train_notmat_toml = tomlkit.load(fp) +LABELSET_NOTMAT = list(str(a_train_notmat_toml["vak"]["prep"]["labelset"])) @pytest.fixture @@ -135,4 +135,5 @@ def annotated_annot_no_segments(request): Used to test edge case for `has_unlabeled`, see https://github.com/vocalpy/vak/issues/378 """ + return request.param diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index dae6e50f4..8fd9b8dd9 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -3,7 +3,7 @@ import shutil import pytest -import toml +import tomlkit from .test_data import GENERATED_TEST_DATA_ROOT, TEST_DATA_ROOT @@ -19,7 +19,7 @@ def test_configs_root(): 1) those used by the tests/scripts/generate_data_for_tests.py script. Will be listed in configs.json. See ``specific_config_toml_path`` fixture below for details about types of configs. - 2) those used by tests that are static, e.g., ``invalid_section_config.toml`` + 2) those used by tests that are static, e.g., ``invalid_table_config.toml`` This fixture facilitates access to type (2), e.g. 
in test_config/test_parse """ @@ -61,30 +61,26 @@ def config_that_doesnt_exist(tmp_path): @pytest.fixture -def invalid_section_config_path(test_configs_root): - return test_configs_root.joinpath("invalid_section_config.toml") +def invalid_table_config_path(test_configs_root): + return test_configs_root.joinpath("invalid_table_config.toml") @pytest.fixture -def invalid_option_config_path(test_configs_root): - return test_configs_root.joinpath("invalid_option_config.toml") - - -GENERATED_TEST_CONFIGS_ROOT = GENERATED_TEST_DATA_ROOT.joinpath("configs") +def invalid_key_config_path(test_configs_root): + return test_configs_root.joinpath("invalid_key_config.toml") @pytest.fixture -def generated_test_configs_root(): - return GENERATED_TEST_CONFIGS_ROOT +def invalid_train_and_learncurve_config_toml(test_configs_root): + return test_configs_root.joinpath("invalid_train_and_learncurve_config.toml") -ALL_GENERATED_CONFIGS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("*toml")) +GENERATED_TEST_CONFIGS_ROOT = GENERATED_TEST_DATA_ROOT.joinpath("configs") -# ---- path to config files ---- @pytest.fixture -def all_generated_configs(): - return ALL_GENERATED_CONFIGS +def generated_test_configs_root(): + return GENERATED_TEST_CONFIGS_ROOT @pytest.fixture @@ -101,7 +97,7 @@ def specific_config_toml_path(generated_test_configs_root, list_of_schematized_c If ``root_results_dir`` argument is specified when calling the factory function, - it will convert the value for that option in the section + it will convert the value for that key in the table corresponding to ``config_type`` to the value specified for ``root_results_dir``. This makes it possible to dynamically change the ``root_results_dir`` @@ -114,7 +110,7 @@ def _specific_config( annot_format, audio_format=None, spect_format=None, - options_to_change=None, + keys_to_change=None, ): """returns path to a specific configuration file, determined by characteristics specified by the caller: @@ -128,18 +124,18 @@ def _specific_config( annotation format, recognized by ``crowsetta`` audio_format : str spect_format : str - options_to_change : list, dict - list of dicts with keys 'section', 'option', and 'value'. - Can be a single dict, in which case only that option is changed. - If the 'value' is set to 'DELETE-OPTION', - the option will be removed from the config. - This can be used to test behavior when the option is not set. + keys_to_change : list, dict + list of dicts with keys 'table', 'key', and 'value'. + Can be a single dict, in which case only that key is changed. + If the 'value' is set to 'DELETE-KEY', + the key will be removed from the config. + This can be used to test behavior when the key is not set. 
Returns ------- config_path : pathlib.Path that points to temporary copy of specified config, - with any options changed as specified + with any keys changed as specified """ original_config_path = None for schematized_config in list_of_schematized_configs: @@ -166,63 +162,100 @@ def _specific_config( config_copy_path = tmp_path.joinpath(original_config_path.name) config_copy_path = shutil.copy(src=original_config_path, dst=config_copy_path) - if options_to_change is not None: - if isinstance(options_to_change, dict): - options_to_change = [options_to_change] - elif isinstance(options_to_change, list): + if keys_to_change is not None: + if isinstance(keys_to_change, dict): + keys_to_change = [keys_to_change] + elif isinstance(keys_to_change, list): pass else: raise TypeError( - f"invalid type for `options_to_change`: {type(options_to_change)}" + f"invalid type for `keys_to_change`: {type(keys_to_change)}" ) with config_copy_path.open("r") as fp: - config_toml = toml.load(fp) - - for opt_dict in options_to_change: - if opt_dict["value"] == 'DELETE-OPTION': - # e.g., to test behavior of config without this option - del config_toml[opt_dict["section"]][opt_dict["option"]] + tomldoc = tomlkit.load(fp) + + for table_key_val_dict in keys_to_change: + table_name = table_key_val_dict["table"] + key = table_key_val_dict["key"] + value = table_key_val_dict["value"] + if isinstance(key, str): + if table_key_val_dict["value"] == 'DELETE-KEY': + # e.g., to test behavior of config without this key + del tomldoc["vak"][table_name][key] + else: + tomldoc["vak"][table_name][key] = value + elif isinstance(key, list) and len(key) == 2 and all([isinstance(el, str) for el in key]): + # for the case where we need to access a sub-table + # right now this applies mainly to ["vak"][table]["dataset"]["path"] + # if we end up having to access more / deeper then we'll need something more general + if table_key_val_dict["value"] == 'DELETE-KEY': + # e.g., to test behavior of config without this key + del tomldoc["vak"][table_name][key[0]][key[1]] + else: + tomldoc["vak"][table_name][key[0]][key[1]] = value else: - config_toml[opt_dict["section"]][opt_dict["option"]] = opt_dict["value"] + raise ValueError( + f"Unexpected value for 'key' in `keys_to_change` dict: {key}.\n" + f"`keys_to_change` dict: {table_key_val_dict}" + ) with config_copy_path.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(tomldoc, fp) return config_copy_path return _specific_config -@pytest.fixture -def all_generated_train_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("test_train*toml")) - - -ALL_GENERATED_LEARNCURVE_CONFIGS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("test_learncurve*toml")) +ALL_GENERATED_CONFIG_PATHS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("*toml")) -@pytest.fixture -def all_generated_learncurve_configs(generated_test_configs_root): - return ALL_GENERATED_LEARNCURVE_CONFIGS - - -@pytest.fixture -def all_generated_eval_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("test_eval*toml")) +# ---- path to config files ---- +@pytest.fixture(params=ALL_GENERATED_CONFIG_PATHS) +def a_generated_config_path(request): + return request.param -@pytest.fixture -def all_generated_predict_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("test_predict*toml")) +def _tomlkit_to_popo(d): + """Convert tomlkit to "popo" (Plain-Old Python Objects) -# ---- config toml from paths ---- -def _return_toml(toml_path): - """return 
config files loaded into dicts with toml library - used to test functions that parse config sections, taking these dicts as inputs""" + From https://github.com/python-poetry/tomlkit/issues/43#issuecomment-660415820 + """ + try: + result = getattr(d, "value") + except AttributeError: + result = d + + if isinstance(result, list): + result = [_tomlkit_to_popo(x) for x in result] + elif isinstance(result, dict): + result = { + _tomlkit_to_popo(key): _tomlkit_to_popo(val) for key, val in result.items() + } + elif isinstance(result, tomlkit.items.Integer): + result = int(result) + elif isinstance(result, tomlkit.items.Float): + result = float(result) + elif isinstance(result, tomlkit.items.String): + result = str(result) + elif isinstance(result, tomlkit.items.Bool): + result = bool(result) + + return result + + +# ---- config dicts from paths ---- +def _load_config_dict(toml_path): + """Return config as dict, loaded from toml file. + + Used to test functions that parse config tables, taking these dicts as inputs. + + Note that we access the topmost table loaded from the toml: config_dict['vak'] + """ with toml_path.open("r") as fp: - config_toml = toml.load(fp) - return config_toml + config_dict = tomlkit.load(fp) + return _tomlkit_to_popo(config_dict['vak']) @pytest.fixture @@ -244,42 +277,70 @@ def _specific_config_toml( config_path = specific_config_toml_path( config_type, model, annot_format, audio_format, spect_format ) - return _return_toml(config_path) + return _load_config_dict(config_path) return _specific_config_toml -ALL_GENERATED_CONFIGS_TOML = [_return_toml(config) for config in ALL_GENERATED_CONFIGS] +@pytest.fixture(params=ALL_GENERATED_CONFIG_PATHS) +def a_generated_config_dict(request): + # we remake dict every time this gets called + # so that we're not returning a ``config_dict`` that was + # already mutated by a `Config.from_config_dict` function, + # e.g. the value for the 'spect_params' key gets mapped to a SpectParamsConfig + # by PrepConfig.from_config_dict + return _load_config_dict(request.param) -@pytest.fixture -def all_generated_configs_toml(): - return ALL_GENERATED_CONFIGS_TOML +ALL_GENERATED_EVAL_CONFIG_PATHS = sorted( GENERATED_TEST_CONFIGS_ROOT.glob("*eval*toml") ) +ALL_GENERATED_LEARNCURVE_CONFIG_PATHS = sorted( GENERATED_TEST_CONFIGS_ROOT.glob("*learncurve*toml") ) -@pytest.fixture -def all_generated_train_configs_toml(all_generated_train_configs): - return [_return_toml(config) for config in all_generated_train_configs] +ALL_GENERATED_PREDICT_CONFIG_PATHS = sorted( GENERATED_TEST_CONFIGS_ROOT.glob("*predict*toml") ) +ALL_GENERATED_TRAIN_CONFIG_PATHS = sorted( GENERATED_TEST_CONFIGS_ROOT.glob("*train*toml") ) -@pytest.fixture -def all_generated_learncurve_configs_toml(all_generated_learncurve_configs): - return [_return_toml(config) for config in all_generated_learncurve_configs] +# as above, we remake dict every time these fixtures get called +# so that we're not returning a ``config_dict`` that was +# already mutated by a `Config.from_config_dict` function, +# e.g.
the value for the 'spect_params' key gets mapped to a SpectParamsConfig +# by PrepConfig.from_config_dict +@pytest.fixture(params=ALL_GENERATED_EVAL_CONFIG_PATHS) +def a_generated_eval_config_dict(request): + return _load_config_dict(request.param) -@pytest.fixture -def all_generated_eval_configs_toml(all_generated_eval_configs): - return [_return_toml(config) for config in all_generated_eval_configs] +@pytest.fixture(params=ALL_GENERATED_LEARNCURVE_CONFIG_PATHS) +def a_generated_learncurve_config_dict(request): + return _load_config_dict(request.param) + + +@pytest.fixture(params=ALL_GENERATED_PREDICT_CONFIG_PATHS) +def a_generated_predict_config_dict(request): + return _load_config_dict(request.param) + + +@pytest.fixture(params=ALL_GENERATED_TRAIN_CONFIG_PATHS) +def a_generated_train_config_dict(request): + return _load_config_dict(request.param) @pytest.fixture -def all_generated_predict_configs_toml(all_generated_predict_configs): - return [_return_toml(config) for config in all_generated_predict_configs] +def all_generated_learncurve_configs_toml(all_generated_learncurve_configs): + return [_load_config_dict(config) for config in all_generated_learncurve_configs] ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS = list(zip( - [_return_toml(config) for config in ALL_GENERATED_CONFIGS], - ALL_GENERATED_CONFIGS, + [_load_config_dict(config) for config in ALL_GENERATED_CONFIG_PATHS], + ALL_GENERATED_CONFIG_PATHS, )) @@ -293,10 +354,10 @@ def all_generated_configs_toml_path_pairs(): """ # we duplicate the constant above because we need to remake # the variables for each unit test. Otherwise tests that modify values - # for config options cause other tests to fail + # for config keys cause other tests to fail return zip( - [_return_toml(config) for config in ALL_GENERATED_CONFIGS], - ALL_GENERATED_CONFIGS + [_load_config_dict(config) for config in ALL_GENERATED_CONFIG_PATHS], + ALL_GENERATED_CONFIG_PATHS ) @@ -307,13 +368,13 @@ def configs_toml_path_pairs_by_model_factory(all_generated_configs_toml_path_pai """ def _wrapped(model, - section_name=None): + table_name=None): out = [] unzipped = list(all_generated_configs_toml_path_pairs) for config_toml, toml_path in unzipped: if toml_path.name.startswith(model): - if section_name: - if section_name.lower() in toml_path.name: + if table_name: + if table_name.lower() in toml_path.name: out.append( (config_toml, toml_path) ) @@ -325,54 +386,3 @@ def _wrapped(model, return _wrapped - -@pytest.fixture -def all_generated_train_configs_toml_path_pairs(all_generated_train_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_train_configs], - all_generated_train_configs, - ) - - -@pytest.fixture -def all_generated_learncurve_configs_toml_path_pairs(all_generated_learncurve_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_learncurve_configs], - all_generated_learncurve_configs, - ) - - -@pytest.fixture -def all_generated_eval_configs_toml_path_pairs(all_generated_eval_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into 
a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_eval_configs], - all_generated_eval_configs, - ) - - -@pytest.fixture -def all_generated_predict_configs_toml_path_pairs(all_generated_predict_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_predict_configs], - all_generated_predict_configs, - ) diff --git a/tests/fixtures/csv.py b/tests/fixtures/csv.py index e4ef71761..0e21cc196 100644 --- a/tests/fixtures/csv.py +++ b/tests/fixtures/csv.py @@ -25,11 +25,11 @@ def _specific_csv_path( config_toml = specific_config_toml( config_type, model, annot_format, audio_format, spect_format ) - dataset_path = Path(config_toml[config_type.upper()]["dataset_path"]) + dataset_path = Path(config_toml[config_type]["dataset"]["path"]) # TODO: make this more general -- dataset registry? - if config_toml['PREP']['dataset_type'] == 'frame classification': + if config_toml['prep']['dataset_type'] == 'frame classification': metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) - elif config_toml['PREP']['dataset_type'] == 'parametric umap': + elif config_toml['prep']['dataset_type'] == 'parametric umap': metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename return dataset_csv_path diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py index 51d1344b0..cb4382269 100644 --- a/tests/fixtures/dataset.py +++ b/tests/fixtures/dataset.py @@ -22,7 +22,7 @@ def _specific_dataset_path( config_toml = specific_config_toml( config_type, model, annot_format, audio_format, spect_format ) - dataset_path = Path(config_toml[config_type.upper()]["dataset_path"]) + dataset_path = Path(config_toml[config_type]["dataset"]["path"]) return dataset_path return _specific_dataset_path diff --git a/tests/scripts/vaktestdata/configs.py b/tests/scripts/vaktestdata/configs.py index cde39be49..3134f7f57 100644 --- a/tests/scripts/vaktestdata/configs.py +++ b/tests/scripts/vaktestdata/configs.py @@ -3,8 +3,7 @@ import pathlib import shutil -# TODO: use tomli -import toml +import tomlkit import vak.cli.prep from . import constants @@ -49,10 +48,9 @@ def add_dataset_path_from_prepped_configs(): """This helper function goes through all configs in :data:`vaktestdata.constants.CONFIG_METADATA` and for any that have a filename for the attribute - "use_dataset_from_config", it sets the option 'dataset_path' + "use_dataset_from_config", it sets the key 'path' in the 'dataset' table in the config file that the metadata corresponds to - to the same option from the file specified - by the attribute. + to the same value from the file specified by the attribute. 
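+
+    For example (the path shown here is illustrative), a config that names
+    another config in "use_dataset_from_config" ends up with a sub-table like:
+
+        [vak.train.dataset]
+        path = "tests/data_for_tests/generated/prep/train/..."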
""" configs_to_change = [ config_metadata @@ -63,27 +61,30 @@ def add_dataset_path_from_prepped_configs(): for config_metadata in configs_to_change: config_to_change_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.filename if config_metadata.config_type == 'train_continue': - section = 'TRAIN' + table_to_add_dataset = 'train' else: - section = config_metadata.config_type.upper() + table_to_add_dataset = config_metadata.config_type config_dataset_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.use_dataset_from_config - with config_dataset_path.open("r") as fp: - dataset_config_toml = toml.load(fp) - purpose = vak.cli.prep.purpose_from_toml(dataset_config_toml) + config_dict = vak.config.load._load_toml_from_path(config_dataset_path) # next line, we can't use `section` here because we could get a KeyError, - # e.g., when the config we are rewriting is an EVAL config, but - # the config we are getting the dataset from is a TRAIN config. + # e.g., when the config we are rewriting is an ``[vak.eval]`` config, but + # the config we are getting the dataset from is a ``[vak.train]`` config. # so instead we use `purpose_from_toml` to get the `purpose` # of the config we are getting the dataset from. - dataset_config_section = purpose.upper() # need to be 'TRAIN', not 'train' - dataset_path = dataset_config_toml[dataset_config_section]['dataset_path'] - with config_to_change_path.open("r") as fp: - config_to_change_toml = toml.load(fp) - config_to_change_toml[section]['dataset_path'] = dataset_path + dataset_config_section = vak.cli.prep.purpose_from_toml(config_dict) + dataset_path = config_dict[dataset_config_section]['dataset']['path'] + + # we open config using tomlkit so we can add path to dataset table in style-preserving way + with config_to_change_path.open('r') as fp: + tomldoc = tomlkit.load(fp) + if 'dataset' not in tomldoc['vak'][table_to_add_dataset]: + dataset_table = tomlkit.table() + tomldoc["vak"][table_to_add_dataset].add("dataset", dataset_table) + tomldoc["vak"][table_to_add_dataset]["dataset"].add("path", str(dataset_path)) with config_to_change_path.open("w") as fp: - toml.dump(config_to_change_toml, fp) + tomlkit.dump(tomldoc, fp) def fix_options_in_configs(config_metadata_list, command, single_train_result=True): @@ -104,8 +105,8 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # now use the config to find the results dir and get the values for the options we need to set # which are checkpoint_path, spect_scaler_path, and labelmap_path with config_to_use_result_from.open("r") as fp: - config_toml = toml.load(fp) - root_results_dir = pathlib.Path(config_toml["TRAIN"]["root_results_dir"]) + config_toml = tomlkit.load(fp) + root_results_dir = pathlib.Path(config_toml["vak"]["train"]["root_results_dir"]) results_dir = sorted(root_results_dir.glob("results_*")) if len(results_dir) > 1: if single_train_result: @@ -130,7 +131,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # these are the only options whose values we need to change # and they are the same for both predict and eval checkpoint_path = sorted(results_dir.glob("**/checkpoints/checkpoint.pt"))[0] - if 'normalize_spectrograms' in config_toml['TRAIN'] and config_toml['TRAIN']['normalize_spectrograms']: + if 'normalize_spectrograms' in config_toml["vak"]['train'] and config_toml["vak"]['train']['normalize_spectrograms']: spect_scaler_path = sorted(results_dir.glob("StandardizeSpect"))[0] else: spect_scaler_path = None @@ 
-150,23 +151,23 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # now add these values to corresponding options in predict / eval config with config_to_fix.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) if command == 'train_continue': - section = 'TRAIN' + table = 'train' else: - section = command.upper() + table = command - config_toml[section]["checkpoint_path"] = str(checkpoint_path) + config_toml["vak"][table]["checkpoint_path"] = str(checkpoint_path) if spect_scaler_path: - config_toml[section]["spect_scaler_path"] = str(spect_scaler_path) + config_toml["vak"][table]["spect_scaler_path"] = str(spect_scaler_path) else: - if 'spect_scaler_path' in config_toml[section]: + if 'spect_scaler_path' in config_toml["vak"][table]: # remove any existing 'spect_scaler_path' option - del config_toml[section]["spect_scaler_path"] + del config_toml["vak"][table]["spect_scaler_path"] if command != 'train_continue': # train always gets labelmap from dataset dir, not from a config option if labelmap_path is not None: - config_toml[section]["labelmap_path"] = str(labelmap_path) + config_toml["vak"][table]["labelmap_path"] = str(labelmap_path) with config_to_fix.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(config_toml, fp) diff --git a/tests/scripts/vaktestdata/source_files.py b/tests/scripts/vaktestdata/source_files.py index e53d0e2ee..d1f25bc4e 100644 --- a/tests/scripts/vaktestdata/source_files.py +++ b/tests/scripts/vaktestdata/source_files.py @@ -7,7 +7,7 @@ warnings.simplefilter('ignore', category=NumbaDeprecationWarning) import pandas as pd -import toml +import tomlkit import vak @@ -47,14 +47,14 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): f"\nRunning :func:`vak.prep.frame_classification.get_or_make_source_files` to generate data for tests, " f"using config:\n{config_path.name}" ) - cfg = vak.config.parse.from_toml_path(config_path) + cfg = vak.config.Config.from_toml_path(config_path, tables_to_parse='prep') source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files( data_dir=cfg.prep.data_dir, input_type=cfg.prep.input_type, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, spect_output_dir=spect_output_dir, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, @@ -72,7 +72,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): csv_path = constants.GENERATED_SOURCE_FILES_CSV_DIR / f'{config_metadata.filename}-source-files.csv' source_files_df.to_csv(csv_path, index=False) - config_toml: dict = vak.config.parse._load_toml_from_path(config_path) + config_toml: dict = vak.config.load._load_toml_from_path(config_path) purpose = vak.cli.prep.purpose_from_toml(config_toml, config_path) dataset_df: pd.DataFrame = vak.prep.frame_classification.assign_samples_to_splits( purpose, @@ -103,20 +103,20 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): ) with config_path.open("r") as fp: - config_toml = toml.load(fp) + tomldoc = tomlkit.load(fp) data_dir = constants.GENERATED_TEST_DATA_ROOT / config_metadata.data_dir - config_toml['PREP']['data_dir'] = str(data_dir) + tomldoc['vak']['prep']['data_dir'] = str(data_dir) with config_path.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(tomldoc, fp) - cfg = vak.config.parse.from_toml_path(config_path) + cfg = vak.config.Config.from_toml_path(config_path, 
tables_to_parse='prep') source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files( data_dir=cfg.prep.data_dir, input_type=cfg.prep.input_type, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, spect_output_dir=None, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, @@ -127,7 +127,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): csv_path = constants.GENERATED_SOURCE_FILES_CSV_DIR / f'{config_metadata.filename}-source-files.csv' source_files_df.to_csv(csv_path, index=False) - config_toml: dict = vak.config.parse._load_toml_from_path(config_path) + config_toml: dict = vak.config.load._load_toml_from_path(config_path) purpose = vak.cli.prep.purpose_from_toml(config_toml, config_path) dataset_df: pd.DataFrame = vak.prep.frame_classification.assign_samples_to_splits( purpose, @@ -159,13 +159,13 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): f"\nRunning :func:`vak.prep.frame_classification.get_or_make_source_files` to generate data for tests, " f"using config:\n{config_path.name}" ) - cfg = vak.config.parse.from_toml_path(config_path) + cfg = vak.config.Config.from_toml_path(config_path, tables_to_parse='prep') source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files( data_dir=cfg.prep.data_dir, input_type=cfg.prep.input_type, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, spect_output_dir=None, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, @@ -176,7 +176,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): csv_path = constants.GENERATED_SOURCE_FILES_CSV_DIR / f'{config_metadata.filename}-source-files.csv' source_files_df.to_csv(csv_path, index=False) - config_toml: dict = vak.config.parse._load_toml_from_path(config_path) + config_toml: dict = vak.config.load._load_toml_from_path(config_path) purpose = vak.cli.prep.purpose_from_toml(config_toml, config_path) dataset_df: pd.DataFrame = vak.prep.frame_classification.assign_samples_to_splits( purpose, diff --git a/tests/test_cli/test_eval.py b/tests/test_cli/test_eval.py index f94f68f46..0ee9aba65 100644 --- a/tests/test_cli/test_eval.py +++ b/tests/test_cli/test_eval.py @@ -25,9 +25,9 @@ def test_eval( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -36,7 +36,7 @@ def test_eval( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.eval.eval', autospec=True) as mock_core_eval: @@ -48,14 +48,14 @@ def test_eval( assert cli_asserts.log_file_contains_version(command="eval", output_path=output_dir) -def test_eval_dataset_path_none_raises( - specific_config_toml_path, tmp_path, +def test_eval_dataset_none_raises( + specific_config_toml_path ): - """Test that cli.eval raises ValueError when dataset_path is None + """Test that cli.eval raises KeyError when there is no dataset table (presumably because `vak prep` was not run yet) """ - options_to_change = [ - 
{"section": "EVAL", "option": "dataset_path", "value": "DELETE-OPTION"}, + keys_to_change = [ + {"table": "eval", "key": "dataset", "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -64,8 +64,8 @@ def test_eval_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.eval.eval(toml_path) diff --git a/tests/test_cli/test_learncurve.py b/tests/test_cli/test_learncurve.py index 8dce64302..7fd0a3a8b 100644 --- a/tests/test_cli/test_learncurve.py +++ b/tests/test_cli/test_learncurve.py @@ -14,13 +14,13 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, - {"section": "LEARNCURVE", "option": "device", "value": device}, + {"table": "learncurve", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -28,7 +28,7 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.learncurve.learning_curve', autospec=True) as mock_core_learning_curve: @@ -44,26 +44,26 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): assert cli_asserts.log_file_contains_version(command="learncurve", output_path=results_path) -def test_learning_curve_dataset_path_none_raises( +def test_learning_curve_dataset_none_raises( specific_config_toml_path, tmp_path, ): """Test that cli.learncurve.learning_curve - raises ValueError when dataset_path is None + raises ValueError when dataset is None (presumably because `vak prep` was not run yet) """ root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, { - "section": "LEARNCURVE", - "option": "dataset_path", - "value": "DELETE-OPTION"}, + "table": "learncurve", + "key": "dataset", + "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -72,8 +72,8 @@ def test_learning_curve_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.learncurve.learning_curve(toml_path) diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py index 6269c01d9..30c78d3c5 100644 --- a/tests/test_cli/test_predict.py +++ b/tests/test_cli/test_predict.py @@ -24,9 +24,9 @@ def test_predict( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": device}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -34,7 +34,7 @@ def test_predict( model=model_name, audio_format=audio_format, annot_format=annot_format, 
- options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.predict.predict', autospec=True) as mock_core_predict: @@ -45,14 +45,14 @@ def test_predict( assert cli_asserts.log_file_contains_version(command="predict", output_path=output_dir) -def test_predict_dataset_path_none_raises( - specific_config_toml_path, tmp_path, +def test_predict_dataset_none_raises( + specific_config_toml_path ): - """Test that cli.predict raises ValueError when dataset_path is None + """Test that cli.predict raises KeyError when there is no dataset table (presumably because `vak prep` was not run yet) """ - options_to_change = [ - {"section": "PREDICT", "option": "dataset_path", "value": "DELETE-OPTION"}, + keys_to_change = [ + {"table": "predict", "key": "dataset", "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -61,8 +61,8 @@ def test_predict_dataset_path_none_raises( config_type="predict", model="TweetyNet", audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.predict.predict(toml_path) diff --git a/tests/test_cli/test_prep.py b/tests/test_cli/test_prep.py index cfdd8453e..88a7c0a35 100644 --- a/tests/test_cli/test_prep.py +++ b/tests/test_cli/test_prep.py @@ -35,7 +35,7 @@ def test_purpose_from_toml( annot_format=annot_format, spect_format=spect_format, ) - config_toml = vak.config.parse._load_toml_from_path(toml_path) + config_toml = vak.config.load._load_toml_from_path(toml_path) vak.cli.prep.purpose_from_toml(config_toml) @@ -64,13 +64,13 @@ def test_prep( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREP", "option": "output_dir", "value": str(output_dir)}, + keys_to_change = [ + {"table": "prep", "key": "output_dir", "value": str(output_dir)}, - # need to remove dataset_path option from configs we already ran prep on to avoid error + # need to remove the dataset key from configs we already ran prep on to avoid error { - "section": config_type.upper(), - "option": "dataset_path", - "value": None, + "table": config_type, + "key": "dataset", + "value": "DELETE-KEY", }, ] toml_path = specific_config_toml_path( @@ -79,7 +79,7 @@ def test_prep( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.prep.prep', autospec=True) as mock_core_prep: @@ -98,23 +98,23 @@ def test_prep( ("train", None, "mat", "yarden"), ], ) -def test_prep_dataset_path_raises( +def test_prep_dataset_raises( config_type, audio_format, spect_format, annot_format, - specific_config_toml_path, + specific_config_toml_path, default_model, tmp_path, - ): +): + """Test that prep raises a ValueError when the config already has a dataset with a path""" output_dir = tmp_path.joinpath( f"test_prep_{config_type}_{audio_format}_{spect_format}_{annot_format}" ) output_dir.mkdir() - options_to_change = [ - {"section": "PREP", "option": "output_dir", "value": str(output_dir)}, + keys_to_change = [ + {"table": "prep", "key": "output_dir", "value": str(output_dir)}, ] toml_path = specific_config_toml_path( config_type=config_type, @@ -122,7 +122,7 @@ def test_prep_dataset_path_raises( model=default_model, audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with pytest.raises(ValueError): diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py index c59716ff2..a23acab3c 100644 --- a/tests/test_cli/test_train.py +++ b/tests/test_cli/test_train.py @@ -24,13 +24,13 @@ def test_train( root_results_dir = 
tmp_path.joinpath("test_train_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "TRAIN", - "option": "root_results_dir", + "table": "train", + "key": "root_results_dir", "value": str(root_results_dir), }, - {"section": "TRAIN", "option": "device", "value": device}, + {"table": "train", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -39,7 +39,7 @@ def test_train( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.train.train', autospec=True) as mock_core_train: @@ -63,9 +63,9 @@ def test_train_dataset_path_none_raises( root_results_dir = tmp_path.joinpath("test_train_root_results_dir") root_results_dir.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "root_results_dir", "value": str(root_results_dir)}, - {"section": "TRAIN", "option": "dataset_path", "value": "DELETE-OPTION"}, + keys_to_change = [ + {"table": "train", "key": "root_results_dir", "value": str(root_results_dir)}, + {"table": "train", "key": "dataset", "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -74,8 +74,8 @@ def test_train_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.train.train(toml_path) diff --git a/tests/test_config/__init__.py b/tests/test_config/__init__.py index eb50111ee..66aaa2a8c 100644 --- a/tests/test_config/__init__.py +++ b/tests/test_config/__init__.py @@ -1,4 +1,4 @@ -from . import test_parse +from . import test_load from . import test_predict from . import test_prep from . import test_spect_params diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index ebf814896..04622f1b7 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -1,26 +1,128 @@ +import pytest + import vak.config -def test_config_attrs_class( - all_generated_configs_toml_path_pairs, - default_model, -): - """test that instantiating Config class works as expected""" - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if default_model not in str(toml_path): - continue # only need to check configs for one model - # also avoids FileNotFoundError on CI - # this is basically the body of the ``config.parse.from_toml`` function. - config_dict = {} - for section_name in list(vak.config.parse.SECTION_CLASSES.keys()): - if section_name in config_toml: - vak.config.validators.are_options_valid( - config_toml, section_name, toml_path - ) - section = vak.config.parse.parse_config_section( - config_toml, section_name, toml_path +class TestConfig: + @pytest.mark.parametrize( + 'tables_to_parse', + [ + None, + 'prep', + ['prep'], + ] + ) + def test_init_with_real_config( + self, a_generated_config_dict, tables_to_parse + ): + """Test that instantiating Config class works as expected""" + # this is basically the body of the ``Config.from_config_dict`` function. 
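+        # (spelled out here rather than calling the classmethod: each top-level
+        # table name is looked up in TABLE_CLASSES_MAP, and the matching attrs
+        # class is built from that table's sub-dict, then passed to Config directly)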
+ config_kwargs = {} + + if tables_to_parse is None: + for table_name in a_generated_config_dict: + config_kwargs[table_name] = vak.config.config.TABLE_CLASSES_MAP[table_name].from_config_dict( + a_generated_config_dict[table_name] ) - config_dict[section_name.lower()] = section + else: + for table_name in a_generated_config_dict: + if table_name in tables_to_parse: + config_kwargs[table_name] = vak.config.config.TABLE_CLASSES_MAP[table_name].from_config_dict( + a_generated_config_dict[table_name] + ) + + config = vak.config.Config(**config_kwargs) + + assert isinstance(config, vak.config.Config) + # we already test that config loading works for EvalConfig, et al., + # so here we just test that the logic of Config works as expected: + # we should get an attribute for each top-level table that we pass in; + # if we don't pass one in, then its corresponding attribute should be None + for attr in ('eval', 'learncurve', 'predict', 'prep', 'train'): + if tables_to_parse is not None: + if attr in a_generated_config_dict and attr in tables_to_parse: + assert hasattr(config, attr) + else: + assert getattr(config, attr) is None + else: + if attr in a_generated_config_dict: + assert hasattr(config, attr) + + @pytest.mark.parametrize( + 'tables_to_parse', + [ + None, + 'prep', + ['prep'], + ] + ) + def test_from_config_dict_with_real_config( + self, a_generated_config_dict, tables_to_parse + ): + """Test :meth:`Config.from_config_dict`""" + config = vak.config.Config.from_config_dict( + a_generated_config_dict, tables_to_parse=tables_to_parse + ) + + assert isinstance(config, vak.config.Config) + # we already test that config loading works for EvalConfig, et al., + # so here we just test that the logic of Config works as expected: + # we should get an attribute for each top-level table that we pass in; + # if we don't pass one in, then its corresponding attribute should be None + for attr in ('eval', 'learncurve', 'predict', 'prep', 'train'): + if tables_to_parse is not None: + if attr in a_generated_config_dict and attr in tables_to_parse: + assert hasattr(config, attr) + else: + assert getattr(config, attr) is None + else: + if attr in a_generated_config_dict: + assert hasattr(config, attr) + + @pytest.mark.parametrize( + 'tables_to_parse', + [ + None, + 'prep', + ['prep'], + ] + ) + def test_from_toml_path(self, a_generated_config_path, tables_to_parse): + config = vak.config.Config.from_toml_path( + a_generated_config_path, tables_to_parse=tables_to_parse + ) + + assert isinstance(config, vak.config.Config) + + a_generated_config_dict = vak.config.load._load_toml_from_path(a_generated_config_path) + # we already test that config loading works for EvalConfig, et al., + # so here we just test that the logic of Config works as expected: + # we should get an attribute for each top-level table that we pass in; + # if we don't pass one in, then its corresponding attribute should be None + for attr in ('eval', 'learncurve', 'predict', 'prep', 'train'): + if tables_to_parse is not None: + if attr in a_generated_config_dict and attr in tables_to_parse: + assert hasattr(config, attr) + else: + assert getattr(config, attr) is None + else: + if attr in a_generated_config_dict: + assert hasattr(config, attr) + + def test_from_toml_path_raises_when_config_doesnt_exist(self, config_that_doesnt_exist): + with pytest.raises(FileNotFoundError): + vak.config.Config.from_toml_path(config_that_doesnt_exist) + + def test_invalid_table_raises(self, invalid_table_config_path): + with pytest.raises(ValueError): + 
vak.config.Config.from_toml_path(invalid_table_config_path) + + def test_invalid_key_raises(self, invalid_key_config_path): + with pytest.raises(ValueError): + vak.config.Config.from_toml_path(invalid_key_config_path) - config = vak.config.parse.Config(**config_dict) - assert isinstance(config, vak.config.parse.Config) + def test_multiple_top_level_tables_besides_prep_raises(self, invalid_train_and_learncurve_config_toml): + """Test that a .toml config with two top-level tables besides ``[vak.prep]`` raises a ValueError + (in this case ``[vak.train]`` and ``[vak.learncurve]``)""" + with pytest.raises(ValueError): + vak.config.Config.from_toml_path(invalid_train_and_learncurve_config_toml) diff --git a/tests/test_config/test_dataset.py b/tests/test_config/test_dataset.py new file mode 100644 index 000000000..ecc9fed0d --- /dev/null +++ b/tests/test_config/test_dataset.py @@ -0,0 +1,133 @@ +import pathlib + +import pytest + +import vak.config.dataset + + +class TestDatasetConfig: + @pytest.mark.parametrize( + 'path, splits_path, name', + [ + # typical use by a user with default split + ('~/user/prepped/dataset', None, None), + # use by a user with a split specified + ('~/user/prepped/dataset', 'splits/replicate-1.json', None), + # use of a built-in dataset, with a split specified + ('~/datasets/BioSoundSegBench', 'splits/Bengalese-Finch-song-gy6or6-replicate-1.json', 'BioSoundSegBench'), + + ] + ) + def test_init(self, path, splits_path, name): + if name is None and splits_path is None: + dataset_config = vak.config.dataset.DatasetConfig( + path=path + ) + elif name is None: + dataset_config = vak.config.dataset.DatasetConfig( + path=path, + splits_path=splits_path, + ) + else: + dataset_config = vak.config.dataset.DatasetConfig( + name=name, + path=path, + splits_path=splits_path, + ) + assert isinstance(dataset_config, vak.config.dataset.DatasetConfig) + assert dataset_config.path == pathlib.Path(path) + if splits_path is not None: + assert dataset_config.splits_path == pathlib.Path(splits_path) + else: + assert dataset_config.splits_path is None + if name is not None: + assert dataset_config.name == name + else: + assert dataset_config.name is None + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'path': '~/datasets/BioSoundSegBench', + 'splits_path': 'splits/Bengalese-Finch-song-gy6or6-replicate-1.json', + 'name': 'BioSoundSegBench', + }, + { + 'path': '~/user/prepped/dataset', + }, + { + 'path': '~/user/prepped/dataset', + 'splits_path': 'splits/replicate-1.json' + }, + { + 'path': '~/user/prepped/dataset', + 'params': {'window_size': 2000} + }, + { + 'name': 'BioSoundSegBench', + 'path': '~/user/prepped/dataset', + 'params': {'window_size': 2000}, + }, + ] + ) + def test_from_config_dict(self, config_dict): + dataset_config = vak.config.dataset.DatasetConfig.from_config_dict(config_dict) + assert isinstance(dataset_config, vak.config.dataset.DatasetConfig) + assert dataset_config.path == pathlib.Path(config_dict['path']) + if 'splits_path' in config_dict: + assert dataset_config.splits_path == pathlib.Path(config_dict['splits_path']) + else: + assert dataset_config.splits_path is None + if 'name' in config_dict: + assert dataset_config.name == config_dict['name'] + else: + assert dataset_config.name is None + if 'params' in config_dict: + assert dataset_config.params == config_dict['params'] + else: + assert dataset_config.params == {} + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'path': '~/datasets/BioSoundSegBench', + 'splits_path': 
'splits/Bengalese-Finch-song-gy6or6-replicate-1.json', + 'name': 'BioSoundSegBench', + }, + { + 'path': '~/user/prepped/dataset', + }, + { + 'path': '~/user/prepped/dataset', + 'splits_path': 'splits/replicate-1.json' + }, + { + 'path': '~/user/prepped/dataset', + 'params': {'window_size': 2000} + }, + { + 'name': 'BioSoundSegBench', + 'path': '~/user/prepped/dataset', + 'params': {'window_size': 2000}, + }, + ] + ) + def test_asdict(self, config_dict): + dataset_config = vak.config.dataset.DatasetConfig.from_config_dict(config_dict) + + dataset_config_as_dict = dataset_config.asdict() + + assert isinstance(dataset_config_as_dict, dict) + for key in ('name', 'path', 'splits_path', 'params'): + if key in config_dict: + if 'path' in key: + assert dataset_config_as_dict[key] == pathlib.Path(config_dict[key]) + else: + assert dataset_config_as_dict[key] == config_dict[key] + else: + if key == 'params': + assert dataset_config_as_dict[key] == {} + else: + assert dataset_config_as_dict[key] is None diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index 9b0ce2793..de0ac681e 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -1,10 +1,227 @@ """tests for vak.config.eval module""" -import vak.config.eval +import pytest +import vak.config -def test_predict_attrs_class(all_generated_eval_configs_toml): - """test that instantiating EvalConfig class works as expected""" - for config_toml in all_generated_eval_configs_toml: - eval_section = config_toml["EVAL"] - config = vak.config.eval.EvalConfig(**eval_section) - assert isinstance(config, vak.config.eval.EvalConfig) + +class TestEval: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 
'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_from_config_dict(self, config_dict): + eval_config = vak.config.EvalConfig.from_config_dict(config_dict) + + assert isinstance(eval_config, vak.config.EvalConfig) + + def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): + eval_table = a_generated_eval_config_dict["eval"] + + eval_config = vak.config.eval.EvalConfig.from_config_dict(eval_table) + + assert isinstance(eval_config, vak.config.eval.EvalConfig) + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + # missing 'model', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + }, + KeyError + ), + # missing 'checkpoint_path', should raise KeyError + ( + { + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': 
'./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + # missing 'output_dir', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ) + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.EvalConfig.from_config_dict(config_dict) \ No newline at end of file diff --git a/tests/test_config/test_learncurve.py b/tests/test_config/test_learncurve.py index 6fa147c75..6d2d65270 100644 --- a/tests/test_config/test_learncurve.py +++ b/tests/test_config/test_learncurve.py @@ -1,10 +1,204 @@ """tests for vak.config.learncurve module""" +import pytest + import vak.config.learncurve -def test_learncurve_attrs_class(all_generated_learncurve_configs_toml): - """test that instantiating LearncurveConfig class works as expected""" - for config_toml in all_generated_learncurve_configs_toml: - learncurve_section = config_toml["LEARNCURVE"] - config = vak.config.learncurve.LearncurveConfig(**learncurve_section) +class TestLearncurveConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = 
vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + learncurve_config = vak.config.LearncurveConfig(**config_dict) + + assert isinstance(learncurve_config, vak.config.LearncurveConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_from_config_dict(self, config_dict): + learncurve_config = vak.config.LearncurveConfig.from_config_dict(config_dict) + + assert isinstance(learncurve_config, vak.config.LearncurveConfig) + + def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_dict): + """test that instantiating LearncurveConfig class works as expected""" + learncurve_table = a_generated_learncurve_config_dict["learncurve"] + + config = vak.config.learncurve.LearncurveConfig.from_config_dict( + learncurve_table + ) + assert isinstance(config, vak.config.learncurve.LearncurveConfig) + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + # missing 'model', should raise KeyError + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + }, + KeyError + ), + # missing 'root_results_dir', should raise KeyError + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], 
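+                            # illustrative TweetyNet hyperparameters; assuming the
+                            # usual (frequency, time) input layout, pool sizes and
+                            # strides of [4, 1] downsample along frequency only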
+ 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + }, + KeyError + ) + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.LearncurveConfig.from_config_dict(config_dict) diff --git a/tests/test_config/test_load.py b/tests/test_config/test_load.py new file mode 100644 index 000000000..81c0e1809 --- /dev/null +++ b/tests/test_config/test_load.py @@ -0,0 +1,24 @@ +"""tests for vak.config.load module""" +import tomlkit + +import vak.config.load + + +def test__tomlkit_to_pop(a_generated_config_path): + with a_generated_config_path.open('r') as fp: + tomldoc = tomlkit.load(fp) + out = vak.config.load._tomlkit_to_popo(tomldoc) + assert isinstance(out, dict) + assert list(out.keys()) == ["vak"] + + +def test__load_from_toml_path(a_generated_config_path): + config_dict = vak.config.load._load_toml_from_path(a_generated_config_path) + + assert isinstance(config_dict, dict) + + with a_generated_config_path.open('r') as fp: + tomldoc = tomlkit.load(fp) + config_dict_raw = vak.config.load._tomlkit_to_popo(tomldoc) + + assert len(list(config_dict.keys())) == len(list(config_dict_raw["vak"].keys())) diff --git a/tests/test_config/test_model.py b/tests/test_config/test_model.py index 7f5e56c4b..3ba45fd48 100644 --- a/tests/test_config/test_model.py +++ b/tests/test_config/test_model.py @@ -1,62 +1,171 @@ -import copy import pytest -from ..fixtures import ( - ALL_GENERATED_CONFIGS_TOML, - ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS -) - import vak.config.model -def _make_expected_config(model_config: dict) -> dict: - for attr in vak.config.model.MODEL_TABLES: - if attr not in model_config: - model_config[attr] = {} - return model_config - - -@pytest.mark.parametrize( - 'toml_dict', - ALL_GENERATED_CONFIGS_TOML -) -def test_config_from_toml_dict(toml_dict): - for section_name in ('TRAIN', 'EVAL', 'LEARNCURVE', 'PREDICT'): - try: - section = toml_dict[section_name] - except KeyError: - continue - model_name = section['model'] - # we need to copy so that we don't silently fail to detect mistakes - # by comparing a reference to the dict with itself - expected_model_config = copy.deepcopy( - toml_dict[model_name] +class TestModelConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'NonExistentModel': { + 'network': {}, + 'optimizer': {}, + 'loss': {}, + 'metrics': {}, + } + }, + { + 'TweetyNet': { + 'network': {}, + 'optimizer': {'lr': 1e-3}, + 'loss': {}, + 'metrics': {}, + } + }, + ] ) - expected_model_config = _make_expected_config(expected_model_config) - - model_config = vak.config.model.config_from_toml_dict(toml_dict, model_name) - - assert model_config == expected_model_config - - -@pytest.mark.parametrize( - 'toml_dict, toml_path', - ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS -) -def test_config_from_toml_path(toml_dict, toml_path): - for section_name in ('TRAIN', 'EVAL', 'LEARNCURVE', 'PREDICT'): - try: - section = toml_dict[section_name] - except KeyError: - continue - model_name = section['model'] - # we need to copy so that we don't silently fail to detect mistakes - # by comparing a reference to the dict with itself - expected_model_config = copy.deepcopy( - toml_dict[model_name] + def test_init(self, config_dict): + 
name=list(config_dict.keys())[0] + config_dict_from_name = config_dict[name] + + model_config = vak.config.model.ModelConfig( + name=name, + **config_dict_from_name + ) + + assert isinstance(model_config, vak.config.model.ModelConfig) + assert model_config.name == name + for key, val in config_dict_from_name.items(): + assert hasattr(model_config, key) + assert getattr(model_config, key) == val + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'TweetyNet': { + 'optimizer': {'lr': 1e-3}, + } + }, + { + 'TweetyNet': { + 'network': {}, + 'optimizer': {'lr': 1e-3}, + 'loss': {}, + 'metrics': {}, + } + }, + { + 'ED_TCN': { + 'optimizer': {'lr': 1e-3}, + } + }, + { + "ConvEncoderUMAP": { + "optimizer": 1e-3 + } + } + ] + ) + def test_from_config_dict(self, config_dict): + model_config = vak.config.model.ModelConfig.from_config_dict(config_dict) + + name=list(config_dict.keys())[0] + config_dict_from_name = config_dict[name] + assert model_config.name == name + for attr in ('network', 'optimizer', 'loss', 'metrics'): + assert hasattr(model_config, attr) + if attr in config_dict_from_name: + assert getattr(model_config, attr) == config_dict_from_name[attr] + else: + assert getattr(model_config, attr) == {} + + def test_from_config_dict_real_config(self, a_generated_config_dict): + config_dict = None + for table_name in ('train', 'eval', 'predict', 'learncurve'): + if table_name in a_generated_config_dict: + config_dict = a_generated_config_dict[table_name]['model'] + if config_dict is None: + raise ValueError( + f"Didn't find top-level table for config: {a_generated_config_dict}" + ) + + model_config = vak.config.model.ModelConfig.from_config_dict(config_dict) + + name=list(config_dict.keys())[0] + config_dict_from_name = config_dict[name] + assert model_config.name == name + for attr in ('network', 'optimizer', 'loss', 'metrics'): + assert hasattr(model_config, attr) + if attr in config_dict_from_name: + assert getattr(model_config, attr) == config_dict_from_name[attr] + else: + assert getattr(model_config, attr) == {} + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'TweetyNet': { + 'optimizer': {'lr': 1e-3}, + } + }, + { + 'TweetyNet': { + 'network': {}, + 'optimizer': {'lr': 1e-3}, + 'loss': {}, + 'metrics': {}, + } + }, + { + 'ED_TCN': { + 'optimizer': {'lr': 1e-3}, + } + }, + { + "ConvEncoderUMAP": { + "optimizer": 1e-3 + } + } + ] ) - expected_model_config = _make_expected_config(expected_model_config) + def test_asdict(self, config_dict): + model_config = vak.config.model.ModelConfig.from_config_dict(config_dict) - model_config = vak.config.model.config_from_toml_path(toml_path, model_name) + model_config_as_dict = model_config.asdict() - assert model_config == expected_model_config + assert isinstance(model_config_as_dict, dict) + + model_name = list(config_dict.keys())[0] + for key in ('name', 'network', 'optimizer', 'loss', 'metrics'): + if key == 'name': + assert model_config_as_dict[key] == model_name + else: + if key in config_dict[model_name]: + assert model_config_as_dict[key] == config_dict[model_name][key] + else: + assert model_config_as_dict[key] == {} + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + # Raises ValueError because model name not in registry + ( + { + 'NonExistentModel': { + 'network': {}, + 'optimizer': {}, + 'loss': {}, + 'metrics': {}, + } + }, + ValueError, + ) + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + 
vak.config.model.ModelConfig.from_config_dict(config_dict) diff --git a/tests/test_config/test_parse.py b/tests/test_config/test_parse.py deleted file mode 100644 index 70549b34f..000000000 --- a/tests/test_config/test_parse.py +++ /dev/null @@ -1,243 +0,0 @@ -"""tests for vak.config.parse module""" -import copy - -import pytest - -import vak.config -import vak.transforms.transforms -import vak.models - - -@pytest.mark.parametrize( - "section_name", - [ - "DATALOADER", - "EVAL" "LEARNCURVE", - "PREDICT", - "PREP", - "SPECT_PARAMS", - "TRAIN", - ], -) -def test_parse_config_section_returns_attrs_class( - section_name, - configs_toml_path_pairs_by_model_factory, -): - """test that ``vak.config.parse.parse_config_section`` - returns an instance of ``vak.config.learncurve.LearncurveConfig``""" - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet", section_name) - for config_toml, toml_path in config_toml_path_pairs: - config_section_obj = vak.config.parse.parse_config_section( - config_toml=config_toml, - section_name=section_name, - toml_path=toml_path, - ) - assert isinstance( - config_section_obj, vak.config.parse.SECTION_CLASSES[section_name] - ) - - -@pytest.mark.parametrize( - "section_name", - [ - "EVAL", - "LEARNCURVE", - "PREDICT", - "PREP", - "SPECT_PARAMS", - "TRAIN", - ], -) -def test_parse_config_section_missing_options_raises( - section_name, - configs_toml_path_pairs_by_model_factory, -): - """test that configs without the required options in a section raise KeyError""" - if vak.config.parse.REQUIRED_OPTIONS[section_name] is None: - pytest.skip(f"no required options to test for section: {section_name}") - - configs_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet", section_name) - - for config_toml, toml_path in configs_toml_path_pairs: - for option in vak.config.parse.REQUIRED_OPTIONS[section_name]: - config_copy = copy.deepcopy(config_toml) - config_copy[section_name].pop(option) - with pytest.raises(KeyError): - config = vak.config.parse.parse_config_section( - config_toml=config_copy, - section_name=section_name, - toml_path=toml_path, - ) - - -@pytest.mark.parametrize("section_name", ["EVAL", "LEARNCURVE", "PREDICT", "TRAIN"]) -def test_parse_config_section_model_not_installed_raises( - section_name, - configs_toml_path_pairs_by_model_factory, -): - """test that a ValueError is raised when the ``models`` option - in the section specifies names of models that are not installed""" - # we only need one toml, path pair - # so we just call next on the ``zipped`` iterator that our fixture gives us - configs_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - - for config_toml, toml_path in configs_toml_path_pairs: - if section_name.lower() in toml_path.name: - break # use these. 
Only need to test on one - - config_toml[section_name]["model"] = "NotInstalledNet" - with pytest.raises(ValueError): - vak.config.parse.parse_config_section( - config_toml=config_toml, section_name=section_name, toml_path=toml_path - ) - - -def test_parse_prep_section_both_audio_and_spect_format_raises( - all_generated_configs_toml_path_pairs, -): - """test that a config with both an audio and a spect format raises a ValueError""" - # iterate through configs til we find one with an `audio_format` option - # and then we'll add a `spect_format` option to it - found_config_to_use = False - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if "audio_format" in config_toml["PREP"]: - found_config_to_use = True - break - assert found_config_to_use # sanity check - - config_toml["PREP"]["spect_format"] = "mat" - with pytest.raises(ValueError): - vak.config.parse.parse_config_section(config_toml, "PREP", toml_path) - - -def test_parse_prep_section_neither_audio_nor_spect_format_raises( - all_generated_configs_toml_path_pairs, -): - """test that a config without either an audio or a spect format raises a ValueError""" - # iterate through configs til we find one with an `audio_format` option - # and then we'll add a `spect_format` option to it - found_config_to_use = False - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if "audio_format" in config_toml["PREP"]: - found_config_to_use = True - break - assert found_config_to_use # sanity check - - config_toml["PREP"].pop("audio_format") - if "spect_format" in config_toml["PREP"]: - # shouldn't be but humor me - config_toml["PREP"].pop("spect_format") - - with pytest.raises(ValueError): - vak.config.parse.parse_config_section(config_toml, "PREP", toml_path) - - -def test_load_from_toml_path(all_generated_configs): - for toml_path in all_generated_configs: - config_toml = vak.config.parse._load_toml_from_path(toml_path) - assert isinstance(config_toml, dict) - - -def test_load_from_toml_path_raises_when_config_doesnt_exist(config_that_doesnt_exist): - with pytest.raises(FileNotFoundError): - vak.config.parse._load_toml_from_path(config_that_doesnt_exist) - - -def test_from_toml_path_returns_instance_of_config( - all_generated_configs, default_model -): - for toml_path in all_generated_configs: - if default_model not in str(toml_path): - continue # only need to check configs for one model - # also avoids FileNotFoundError on CI - config_obj = vak.config.parse.from_toml_path(toml_path) - assert isinstance(config_obj, vak.config.parse.Config) - - -def test_from_toml_path_raises_when_config_doesnt_exist(config_that_doesnt_exist): - with pytest.raises(FileNotFoundError): - vak.config.parse.from_toml_path(config_that_doesnt_exist) - - -def test_from_toml(configs_toml_path_pairs_by_model_factory): - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - for config_toml, toml_path in config_toml_path_pairs: - config_obj = vak.config.parse.from_toml(config_toml, toml_path) - assert isinstance(config_obj, vak.config.parse.Config) - - -def test_from_toml_parse_prep_with_sections_not_none( - configs_toml_path_pairs_by_model_factory, -): - """test that we get only the sections we want when we pass in a sections list to - ``from_toml``. 
Specifically test ``PREP`` since that's what this will be used for.""" - # only use configs from 'default_model') (TeenyTweetyNet) - # so we are sure paths exist, to avoid NotADirectoryErrors that give spurious test failures - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - for config_toml, toml_path in config_toml_path_pairs: - config_obj = vak.config.parse.from_toml( - config_toml, toml_path, sections=["PREP", "SPECT_PARAMS"] - ) - assert isinstance(config_obj, vak.config.parse.Config) - for should_have in ("prep", "spect_params"): - assert hasattr(config_obj, should_have) - for should_be_none in ("eval", "learncurve", "train", "predict"): - assert getattr(config_obj, should_be_none) is None - assert ( - getattr(config_obj, "dataloader") - == vak.config.dataloader.DataLoaderConfig() - ) - - -@pytest.mark.parametrize("section_name", ["EVAL", "LEARNCURVE", "PREDICT", "TRAIN"]) -def test_from_toml_parse_prep_with_sections_not_none( - section_name, all_generated_configs_toml_path_pairs, random_path_factory -): - """Test that ``config.parse.from_toml`` parameter ``sections`` works as expected. - - If we pass in a list of section names - specifying that we only want to parse ``PREP`` and ``SPECT_PARAMS``, - other sections should be left as None in the return Config instance.""" - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if section_name.lower() in toml_path.name: - break # use these - - purpose = vak.cli.prep.purpose_from_toml(config_toml, toml_path) - section_name = purpose.upper() - required_options = vak.config.parse.REQUIRED_OPTIONS[section_name] - for required_option in required_options: - # set option to values that **would** cause an error if we parse them - if "path" in required_option: - badval = random_path_factory(f"_{required_option}.exe") - elif "dir" in required_option: - badval = random_path_factory("nonexistent_dir") - else: - continue - config_toml[section_name][required_option] = badval - cfg = vak.config.parse.from_toml( - config_toml, toml_path, sections=["PREP", "SPECT_PARAMS"] - ) - assert hasattr(cfg, 'prep') and getattr(cfg, 'prep') is not None - assert hasattr(cfg, 'spect_params') and getattr(cfg, 'spect_params') is not None - assert getattr(cfg, purpose) is None - - -def test_invalid_section_raises(invalid_section_config_path): - with pytest.raises(ValueError): - vak.config.parse.from_toml_path(invalid_section_config_path) - - -def test_invalid_option_raises(invalid_option_config_path): - with pytest.raises(ValueError): - vak.config.parse.from_toml_path(invalid_option_config_path) - - -@pytest.fixture -def invalid_train_and_learncurve_config_toml(test_configs_root): - return test_configs_root.joinpath("invalid_train_and_learncurve_config.toml") - - -def test_train_and_learncurve_defined_raises(invalid_train_and_learncurve_config_toml): - """test that a .toml config with both a TRAIN and a LEARNCURVE section raises a ValueError""" - with pytest.raises(ValueError): - vak.config.parse.from_toml_path(invalid_train_and_learncurve_config_toml) diff --git a/tests/test_config/test_predict.py b/tests/test_config/test_predict.py index c6fcb22f3..8d81dcf07 100644 --- a/tests/test_config/test_predict.py +++ b/tests/test_config/test_predict.py @@ -1,10 +1,183 @@ """tests for vak.config.predict module""" +import pytest + import vak.config.predict -def test_predict_attrs_class(all_generated_predict_configs_toml): - """test that instantiating PredictConfig class works as expected""" - for config_toml in 
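# The parse-module tests deleted above are superseded by tests of a
# `from_config_dict` classmethod on each config class. A minimal sketch of
# that pattern using "modern" attrs; `ExampleTableConfig` and its keys are
# hypothetical stand-ins, not vak's actual API:
import attrs

@attrs.define
class ExampleTableConfig:
    checkpoint_path: str
    batch_size: int = 1

    @classmethod
    def from_config_dict(cls, config_dict: dict) -> "ExampleTableConfig":
        # the rewritten tests assert a KeyError like this one when a
        # required key is missing from the table
        if "checkpoint_path" not in config_dict:
            raise KeyError("config table requires a 'checkpoint_path' key")
        return cls(**config_dict)

example_config = ExampleTableConfig.from_config_dict({"checkpoint_path": "ckpt.pt"})
assert example_config.batch_size == 1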
all_generated_predict_configs_toml: - predict_section = config_toml["PREDICT"] - config = vak.config.predict.PredictConfig(**predict_section) - assert isinstance(config, vak.config.predict.PredictConfig) +class TestPredictConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + predict_config = vak.config.PredictConfig(**config_dict) + + assert isinstance(predict_config, vak.config.PredictConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_from_config_dict(self, config_dict): + predict_config = vak.config.PredictConfig.from_config_dict(config_dict) + + assert isinstance(predict_config, vak.config.PredictConfig) + + def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict): + predict_table = a_generated_predict_config_dict["predict"] + + predict_config = vak.config.predict.PredictConfig.from_config_dict(predict_table) + + assert isinstance(predict_config, vak.config.predict.PredictConfig) + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + # missing 'checkpoint_path', should raise KeyError + ( + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 
'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + }, + KeyError + ), + # missing 'model', should raise KeyError + ( + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.PredictConfig.from_config_dict(config_dict) diff --git a/tests/test_config/test_prep.py b/tests/test_config/test_prep.py index 3912f11f0..9583b708b 100644 --- a/tests/test_config/test_prep.py +++ b/tests/test_config/test_prep.py @@ -1,12 +1,173 @@ """tests for vak.config.prep module""" +import copy + +import pytest + import vak.config.prep -def test_parse_prep_config_returns_PrepConfig_instance( - configs_toml_path_pairs_by_model_factory, -): - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - for config_toml, toml_path in config_toml_path_pairs: - prep_section = config_toml["PREP"] - config = vak.config.prep.PrepConfig(**prep_section) - assert isinstance(config, vak.config.prep.PrepConfig) +class TestPrepConfig: + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 
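# A minimal sketch of where nested config dicts like the ones these tests
# build come from in the version-1.0 format: `model` and `dataset` live as
# sub-tables under a dotted table name. The paths and values here are fake
# placeholders, only the table layout follows the new format:
import tomlkit

toml_src = """
[vak.predict]
checkpoint_path = "/fake/path/checkpoint.pt"
batch_size = 11

[vak.predict.model.TweetyNet.optimizer]
lr = 0.001

[vak.predict.dataset]
path = "/fake/path/dataset"
"""
config_dict = tomlkit.parse(toml_src)["vak"]["predict"]
assert config_dict["model"]["TweetyNet"]["optimizer"]["lr"] == 0.001
assert config_dict["dataset"]["path"] == "/fake/path/dataset"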
'spect_params': {'fft_size': 512, + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312', + 'dataset_type': 'frame classification', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet', + 'spect_params': {'fft_size': 512, + 'freq_cutoffs': [500, 10000], + 'step_size': 64, + 'thresh': 6.25, + 'transform_type': 'log_spect'}, + 'test_dur': 30, + 'train_dur': 50, + 'val_dur': 15 + }, + ] + ) + def test_init(self, config_dict): + config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) + + prep_config = vak.config.PrepConfig(**config_dict) + + assert isinstance(prep_config, vak.config.prep.PrepConfig) + for key, val in config_dict.items(): + assert hasattr(prep_config, key) + if key == 'data_dir' or key == 'output_dir': + assert getattr(prep_config, key) == vak.common.converters.expanded_user_path(val) + elif key == 'labelset': + assert getattr(prep_config, key) == vak.common.converters.labelset_to_set(val) + else: + assert getattr(prep_config, key) == val + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'spect_format': 'mat', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 'spect_params': {'fft_size': 512, + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + ] + ) + def test_both_audio_and_spect_format_raises( + self, config_dict, + ): + """test that a config with both an audio and a spect format raises a ValueError""" + # need to do this set-up so we don't mask one error with another + config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) + + with pytest.raises(ValueError): + prep_config = vak.config.PrepConfig(**config_dict) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 'spect_params': {'fft_size': 512, + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + ] + ) + def test_neither_audio_nor_spect_format_raises( + self, config_dict + ): + """test that a config without either an audio or a spect format raises a ValueError""" + # need to do this set-up so we don't mask one error with another + config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) + + with pytest.raises(ValueError): + prep_config = vak.config.PrepConfig(**config_dict) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 'spect_params': {'fft_size': 512, + 'step_size': 
32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312', + 'dataset_type': 'frame classification', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet', + 'spect_params': {'fft_size': 512, + 'freq_cutoffs': [500, 10000], + 'step_size': 64, + 'thresh': 6.25, + 'transform_type': 'log_spect'}, + 'test_dur': 30, + 'train_dur': 50, + 'val_dur': 15 + }, + ] + ) + def test_from_config_dict(self, config_dict): + # we have to make a copy since `from_config_dict` mutates the dict + config_dict_copy = copy.deepcopy(config_dict) + + prep_config = vak.config.prep.PrepConfig.from_config_dict(config_dict_copy) + + assert isinstance(prep_config, vak.config.prep.PrepConfig) + for key, val in config_dict.items(): + assert hasattr(prep_config, key) + if key == 'data_dir' or key == 'output_dir': + assert getattr(prep_config, key) == vak.common.converters.expanded_user_path(val) + elif key == 'labelset': + assert getattr(prep_config, key) == vak.common.converters.labelset_to_set(val) + elif key == 'spect_params': + assert getattr(prep_config, key) == vak.config.SpectParamsConfig(**val) + else: + assert getattr(prep_config, key) == val + + def test_from_config_dict_real_config( + self, a_generated_config_dict + ): + prep_config = vak.config.prep.PrepConfig.from_config_dict(a_generated_config_dict['prep']) + assert isinstance(prep_config, vak.config.prep.PrepConfig) diff --git a/tests/test_config/test_spect_params.py b/tests/test_config/test_spect_params.py index 10f0fcdeb..5accbb0f1 100644 --- a/tests/test_config/test_spect_params.py +++ b/tests/test_config/test_spect_params.py @@ -45,10 +45,21 @@ def test_freq_cutoffs_wrong_order_raises(): ) -def test_spect_params_attrs_class(all_generated_configs_toml_path_pairs): - """test that instantiating SpectParamsConfig class works as expected""" - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if "SPECT_PARAMS" in config_toml: - spect_params_section = config_toml["SPECT_PARAMS"] - config = vak.config.spect_params.SpectParamsConfig(**spect_params_section) - assert isinstance(config, vak.config.spect_params.SpectParamsConfig) +class TestSpectParamsConfig: + @pytest.mark.parametrize( + 'config_dict', + [ + {'fft_size': 512, 'step_size': 64, 'freq_cutoffs': [500, 10000], 'thresh': 6.25, 'transform_type': 'log_spect'}, + ] + ) + def test_init(self, config_dict): + spect_params_config = vak.config.SpectParamsConfig(**config_dict) + assert isinstance(spect_params_config, vak.config.spect_params.SpectParamsConfig) + + def test_with_real_config(self, a_generated_config_dict): + if "spect_params" in a_generated_config_dict['prep']: + spect_config_dict = a_generated_config_dict['prep']['spect_params'] + else: + pytest.skip("No spect params in config") + spect_params_config = vak.config.spect_params.SpectParamsConfig(**spect_config_dict) + assert isinstance(spect_params_config, vak.config.spect_params.SpectParamsConfig) diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py index 1cde1db3e..e5a3127da 100644 --- a/tests/test_config/test_train.py +++ b/tests/test_config/test_train.py @@ -1,10 +1,159 @@ """tests for vak.config.train module""" +import pytest + import vak.config.train -def test_train_attrs_class(all_generated_train_configs_toml_path_pairs): - """test that instantiating 
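# The deepcopy noted above matters because a `from_config_dict` classmethod
# that converts sub-tables in place mutates the dict it was given. A sketch
# of that side effect with a stand-in converter, not vak's real code:
import copy

def convert_spect_params_in_place(config_dict: dict) -> dict:
    # stand-in for from_config_dict replacing a raw sub-dict with a
    # SpectParamsConfig-like object
    config_dict["spect_params"] = tuple(config_dict["spect_params"].items())
    return config_dict

original = {"spect_params": {"fft_size": 512, "step_size": 32}}
convert_spect_params_in_place(copy.deepcopy(original))
assert isinstance(original["spect_params"], dict)  # original survives intact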
TrainConfig class works as expected""" - for config_toml, toml_path in all_generated_train_configs_toml_path_pairs: - train_section = config_toml["TRAIN"] - train_config_obj = vak.config.train.TrainConfig(**train_section) - assert isinstance(train_config_obj, vak.config.train.TrainConfig) +class TestTrainConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + train_config = vak.config.TrainConfig(**config_dict) + + assert isinstance(train_config, vak.config.TrainConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_from_config_dict(self, config_dict): + train_config = vak.config.TrainConfig.from_config_dict(config_dict) + + assert isinstance(train_config, vak.config.TrainConfig) + + def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): + train_table = a_generated_train_config_dict["train"] + + train_config = vak.config.train.TrainConfig.from_config_dict(train_table) + + assert isinstance(train_config, vak.config.train.TrainConfig) + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + }, + KeyError + ), + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': 
'./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + }, + KeyError + ) + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.TrainConfig.from_config_dict(config_dict) diff --git a/tests/test_config/test_validators.py b/tests/test_config/test_validators.py index 49ae7fad9..9d0b14421 100644 --- a/tests/test_config/test_validators.py +++ b/tests/test_config/test_validators.py @@ -1,25 +1,22 @@ import pytest -import toml import vak.config.validators -def test_are_sections_valid(invalid_section_config_path): - """test that invalid section name raises a ValueError""" - with invalid_section_config_path.open("r") as fp: - config_toml = toml.load(fp) +def test_are_tables_valid(invalid_table_config_path): + """test that invalid table name raises a ValueError""" + config_dict = vak.config.load._load_toml_from_path(invalid_table_config_path) with pytest.raises(ValueError): - vak.config.validators.are_sections_valid( - config_toml, invalid_section_config_path + vak.config.validators.are_tables_valid( + config_dict, invalid_table_config_path ) -def test_are_options_valid(invalid_option_config_path): - """test that section with an invalid option name raises a ValueError""" - section_with_invalid_option = "PREP" - with invalid_option_config_path.open("r") as fp: - config_toml = toml.load(fp) +def test_are_keys_valid(invalid_key_config_path): + """test that table with an invalid key name raises a ValueError""" + table_with_invalid_key = "prep" + config_dict = vak.config.load._load_toml_from_path(invalid_key_config_path) with pytest.raises(ValueError): - vak.config.validators.are_options_valid( - config_toml, section_with_invalid_option, invalid_option_config_path + vak.config.validators.are_keys_valid( + config_dict, table_with_invalid_key, invalid_key_config_path ) diff --git a/tests/test_datasets/test_frame_classification/test_frames_dataset.py b/tests/test_datasets/test_frame_classification/test_frames_dataset.py index a7674ec61..f71c7f9fb 100644 --- a/tests/test_datasets/test_frame_classification/test_frames_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_frames_dataset.py @@ -19,15 +19,18 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo audio_format=audio_format, spect_format=spect_format, annot_format=annot_format) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) + transform_kwargs = { + "window_size": cfg.eval.dataset.params["window_size"] + } item_transform = vak.transforms.defaults.get_default_transform( - model_name, config_type, cfg.eval.transform_params + model_name, config_type, transform_kwargs ) dataset = vak.datasets.frame_classification.FramesDataset.from_dataset_path( - dataset_path=cfg_command.dataset_path, + dataset_path=cfg_command.dataset.path, split=split, item_transform=item_transform, ) diff --git a/tests/test_datasets/test_frame_classification/test_window_dataset.py b/tests/test_datasets/test_frame_classification/test_window_dataset.py index 613fd1854..430917d2a 100644 --- 
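# The validator tests above expect a ValueError for unknown tables or keys.
# A minimal sketch of that style of check; VALID_KEYS is a stand-in for the
# set of keys vak actually derives from its valid-version-1.0.toml file:
VALID_KEYS = {
    "prep": {"data_dir", "output_dir", "labelset", "audio_format", "spect_params"},
}

def are_keys_valid_sketch(config_dict: dict, table_name: str) -> None:
    invalid_keys = set(config_dict[table_name]) - VALID_KEYS[table_name]
    if invalid_keys:
        raise ValueError(
            f"invalid key(s) in '{table_name}' table: {sorted(invalid_keys)}"
        )

are_keys_valid_sketch({"prep": {"data_dir": "./data"}}, "prep")  # passes silently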
a/tests/test_datasets/test_frame_classification/test_window_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_window_dataset.py @@ -20,18 +20,17 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo audio_format=audio_format, spect_format=spect_format, annot_format=annot_format) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) - transform, target_transform = vak.transforms.defaults.get_default_transform( + transform = vak.transforms.defaults.get_default_transform( model_name, config_type, transform_kwargs ) dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=cfg_command.dataset_path, + dataset_path=cfg_command.dataset.path, split=split, - window_size=cfg_command.train_dataset_params['window_size'], - transform=transform, - target_transform=target_transform, + window_size=cfg_command.dataset.params['window_size'], + item_transform=transform, ) assert isinstance(dataset, vak.datasets.frame_classification.WindowDataset) diff --git a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py index 15eab713f..7f2c0bb38 100644 --- a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py +++ b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py @@ -19,7 +19,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo audio_format=audio_format, spect_format=spect_format, annot_format=annot_format) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) transform = vak.transforms.defaults.get_default_transform( @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.parametric_umap.ParametricUMAPDataset.from_dataset_path( - dataset_path=cfg_command.dataset_path, + dataset_path=cfg_command.dataset.path, split=split, transform=transform, ) diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index b4e69322b..7a874afb9 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -29,9 +29,9 @@ def test_eval( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": 'cpu'}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": 'cpu'}, ] toml_path = specific_config_toml_path( @@ -40,26 +40,22 @@ def test_eval( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = vak.config.Config.from_toml_path(toml_path) results_path = tmp_path / 'results_path' results_path.mkdir() with mock.patch(eval_function_to_mock, autospec=True) as mock_eval_function: vak.eval.eval( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, 
batch_size=cfg.eval.batch_size, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, post_tfm_kwargs=cfg.eval.post_tfm_kwargs, diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index ce299c0e6..ee55825a5 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -54,9 +54,9 @@ def test_eval_frame_classification_model( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -65,35 +65,31 @@ def test_eval_frame_classification_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = vak.config.Config.from_toml_path(toml_path) vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, post_tfm_kwargs=post_tfm_kwargs, ) - assert_eval_output_matches_expected(cfg.eval.model, output_dir) + assert_eval_output_matches_expected(cfg.eval.model.name, output_dir) @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, - {"section": "EVAL", "option": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, - {"section": "EVAL", "option": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, + {"table": "eval", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "eval", "key": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, + {"table": "eval", "key": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, ] ) def test_eval_frame_classification_model_raises_file_not_found( @@ -111,9 +107,9 @@ def test_eval_frame_classification_model_raises_file_not_found( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, path_option_to_change, ] @@ -123,21 +119,17 @@ def test_eval_frame_classification_model_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = 
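# The calls above pass `cfg.eval.model.asdict()` in place of the old separate
# model_name + model_config arguments. With attrs-based config classes such a
# method can be a thin wrapper over attrs.asdict; a sketch with a hypothetical
# class, not vak's real ModelConfig:
import attrs

@attrs.define
class ModelConfigSketch:
    name: str
    network: dict = attrs.field(factory=dict)
    optimizer: dict = attrs.field(factory=dict)

    def asdict(self) -> dict:
        return attrs.asdict(self)

model_config = ModelConfigSketch(name="TweetyNet", optimizer={"lr": 0.001})
assert model_config.asdict() == {
    "name": "TweetyNet", "network": {}, "optimizer": {"lr": 0.001}
}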
vak.config.Config.from_toml_path(toml_path) with pytest.raises(FileNotFoundError): vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, ) @@ -146,8 +138,8 @@ def test_eval_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "EVAL", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "eval", "key": ["dataset","path"], "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "eval", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) def test_eval_frame_classification_model_raises_not_a_directory( @@ -159,20 +151,20 @@ def test_eval_frame_classification_model_raises_not_a_directory( """Test that core.eval raises NotADirectory when directories don't exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "EVAL", "option": "device", "value": device}, + {"table": "eval", "key": "device", "value": device}, ] - if path_option_to_change["option"] != "output_dir": + if path_option_to_change["key"] != "output_dir": # need to make sure output_dir *does* exist # so we don't detect spurious NotADirectoryError and assume test passes output_dir = tmp_path.joinpath( f"test_eval_raises_not_a_directory" ) output_dir.mkdir() - options_to_change.append( - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)} + keys_to_change.append( + {"table": "eval", "key": "output_dir", "value": str(output_dir)} ) toml_path = specific_config_toml_path( @@ -181,21 +173,17 @@ def test_eval_frame_classification_model_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(NotADirectoryError): vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, ) diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index 5b803a7e7..4c6c7e573 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -32,9 +32,9 @@ def test_eval_parametric_umap_model( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, 
+ keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -43,31 +43,27 @@ def test_eval_parametric_umap_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = vak.config.Config.from_toml_path(toml_path) vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, device=cfg.eval.device, ) - assert_eval_output_matches_expected(cfg.eval.model, output_dir) + assert_eval_output_matches_expected(cfg.eval.model.name, output_dir) @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "eval", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, ] ) def test_eval_frame_classification_model_raises_file_not_found( @@ -83,9 +79,9 @@ def test_eval_frame_classification_model_raises_file_not_found( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, path_option_to_change, ] @@ -95,21 +91,17 @@ def test_eval_frame_classification_model_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(FileNotFoundError): vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, device=cfg.eval.device, ) @@ -117,8 +109,8 @@ def test_eval_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "EVAL", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "eval", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "eval", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) def test_eval_frame_classification_model_raises_not_a_directory( @@ -129,20 +121,20 @@ def test_eval_frame_classification_model_raises_not_a_directory( ): """Test 
that :func:`vak.eval.parametric_umap.eval_parametric_umap_model` raises NotADirectoryError when expected""" - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "EVAL", "option": "device", "value": device}, + {"table": "eval", "key": "device", "value": device}, ] - if path_option_to_change["option"] != "output_dir": + if path_option_to_change["key"] != "output_dir": # need to make sure output_dir *does* exist # so we don't detect spurious NotADirectoryError and assume test passes output_dir = tmp_path.joinpath( f"test_eval_raises_not_a_directory" ) output_dir.mkdir() - options_to_change.append( - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)} + keys_to_change.append( + {"table": "eval", "key": "output_dir", "value": str(output_dir)} ) toml_path = specific_config_toml_path( @@ -151,21 +143,16 @@ def test_eval_frame_classification_model_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) + cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(NotADirectoryError): vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset_path, + model_config=cfg.eval.model.asdict(), + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, device=cfg.eval.device, ) - diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index cc3484279..c88f4ebf4 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -52,32 +52,26 @@ def assert_learncurve_output_matches_expected(cfg, model_name, results_path): ) def test_learning_curve_for_frame_classification_model( model_name, audio_format, annot_format, specific_config_toml_path, tmp_path, device): - options_to_change = {"section": "LEARNCURVE", "option": "device", "value": device} + keys_to_change = {"table": "learncurve", "key": "device", "value": device} toml_path = specific_config_toml_path( config_type="learncurve", model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) + cfg = vak.config.Config.from_toml_path(toml_path) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( - model_name=cfg.learncurve.model, - model_config=model_config, - dataset_path=cfg.learncurve.dataset_path, + model_config=cfg.learncurve.model.asdict(), + dataset_config=cfg.learncurve.dataset.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, - train_transform_params=cfg.learncurve.train_transform_params, - train_dataset_params=cfg.learncurve.train_dataset_params, - val_transform_params=cfg.learncurve.val_transform_params, - 
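# The keys_to_change entries above now allow a list-valued "key", such as
# ["dataset", "path"], to address a key inside a nested table. A sketch of a
# helper the test fixtures could use to apply one change; hypothetical and
# simplified to plain dicts rather than a tomlkit document:
def apply_key_change(table: dict, key, value) -> None:
    keys = key if isinstance(key, list) else [key]
    for k in keys[:-1]:
        table = table[k]  # walk down into nested tables
    table[keys[-1]] = value

config = {"learncurve": {"dataset": {"path": "old"}, "device": "cpu"}}
apply_key_change(config["learncurve"], ["dataset", "path"], "/new/path")
apply_key_change(config["learncurve"], "device", "cuda")
assert config["learncurve"]["dataset"]["path"] == "/new/path"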
val_dataset_params=cfg.learncurve.val_dataset_params, results_path=results_path, post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, @@ -88,14 +82,14 @@ def test_learning_curve_for_frame_classification_model( device=cfg.learncurve.device, ) - assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model, results_path) + assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model.name, results_path) @pytest.mark.parametrize( 'dir_option_to_change', [ - {"section": "LEARNCURVE", "option": "root_results_dir", "value": '/obviously/does/not/exist/results/'}, - {"section": "LEARNCURVE", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "learncurve", "key": "root_results_dir", "value": '/obviously/does/not/exist/results/'}, + {"table": "learncurve", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, ] ) def test_learncurve_raises_not_a_directory(dir_option_to_change, @@ -105,8 +99,8 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, when the following directories do not exist: results_path, previous_run_path, dataset_path """ - options_to_change = [ - {"section": "LEARNCURVE", "option": "device", "value": device}, + keys_to_change = [ + {"table": "learncurve", "key": "device", "value": device}, dir_option_to_change ] toml_path = specific_config_toml_path( @@ -114,25 +108,19 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) + cfg = vak.config.Config.from_toml_path(toml_path) # mock behavior of cli.learncurve, building `results_path` from config option `root_results_dir` results_path = cfg.learncurve.root_results_dir / 'results-dir-timestamp' with pytest.raises(NotADirectoryError): vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( - model_name=cfg.learncurve.model, - model_config=model_config, - dataset_path=cfg.learncurve.dataset_path, + model_config=cfg.learncurve.model.asdict(), + dataset_config=cfg.learncurve.dataset.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, - train_transform_params=cfg.learncurve.train_transform_params, - train_dataset_params=cfg.learncurve.train_dataset_params, - val_transform_params=cfg.learncurve.val_transform_params, - val_dataset_params=cfg.learncurve.val_dataset_params, results_path=results_path, post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, diff --git a/tests/test_metrics/test_segmentation.py b/tests/test_metrics/test_segmentation.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index 0c3236296..ec9d37ccf 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -191,21 +191,20 @@ def test_load_state_dict_from_path(self, """ definition = self.MODEL_DEFINITION_MAP[model_name] train_toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - train_cfg = vak.config.parse.from_toml_path(train_toml_path) + train_cfg = vak.config.Config.from_toml_path(train_toml_path) # stuff we need just to be able to 
instantiate network labelmap = vak.common.labels.to_map(train_cfg.prep.labelset, map_unlabeled=True) - transform, target_transform = vak.transforms.defaults.get_default_transform( + item_transform = vak.transforms.defaults.get_default_transform( model_name, "train", transform_kwargs={}, ) train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=train_cfg.train.dataset_path, + dataset_path=train_cfg.train.dataset.path, split="train", - window_size=train_cfg.train.train_dataset_params['window_size'], - transform=transform, - target_transform=target_transform, + window_size=train_cfg.train.dataset.params['window_size'], + item_transform=item_transform, ) input_shape = train_dataset.shape num_input_channels = input_shape[-3] @@ -216,7 +215,8 @@ def test_load_state_dict_from_path(self, ) # network is the one thing that has required args # and we also need to use its config from the toml file - model_config = vak.config.model.config_from_toml_path(train_toml_path, model_name) + cfg = vak.config.Config.from_toml_path(train_toml_path) + model_config = cfg.train.model.asdict() network = definition.network(num_classes=len(labelmap), num_input_channels=num_input_channels, num_freqbins=num_freqbins, @@ -225,7 +225,7 @@ def test_load_state_dict_from_path(self, model.to(device) eval_toml_path = specific_config_toml_path('eval', model_name, audio_format='cbin', annot_format='notmat') - eval_cfg = vak.config.parse.from_toml_path(eval_toml_path) + eval_cfg = vak.config.Config.from_toml_path(eval_toml_path) checkpoint_path = eval_cfg.eval.checkpoint_path # ---- actually test method diff --git a/tests/test_models/test_frame_classification_model.py b/tests/test_models/test_frame_classification_model.py index e77e84acf..c66dbcb47 100644 --- a/tests/test_models/test_frame_classification_model.py +++ b/tests/test_models/test_frame_classification_model.py @@ -86,7 +86,7 @@ def test_from_config(self, definition = vak.models.definition.validate(definition) model_name = definition.__name__.replace('Definition', '') toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # stuff we need just to be able to instantiate network labelmap = vak.common.labels.to_map(cfg.prep.labelset, map_unlabeled=True) @@ -95,7 +95,7 @@ def test_from_config(self, vak.models.FrameClassificationModel, 'definition', definition, raising=False ) - config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + config = cfg.train.model.asdict() num_input_channels, num_freqbins = self.MOCK_INPUT_SHAPE[0], self.MOCK_INPUT_SHAPE[1] config["network"].update( diff --git a/tests/test_models/test_parametric_umap_model.py b/tests/test_models/test_parametric_umap_model.py index 0e255fed1..eba4f77d1 100644 --- a/tests/test_models/test_parametric_umap_model.py +++ b/tests/test_models/test_parametric_umap_model.py @@ -87,13 +87,13 @@ def test_from_config( definition = vak.models.definition.validate(definition) model_name = definition.__name__.replace('Definition', '') toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) monkeypatch.setattr( vak.models.ParametricUMAPModel, 'definition', definition, raising=False ) - config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + config = 
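# Every test above now builds one Config object via
# `vak.config.Config.from_toml_path` instead of module-level parse functions.
# A sketch of such a classmethod; ConfigSketch is hypothetical, and vak's
# real class also validates tables and keys before constructing anything:
import pathlib

import tomlkit

class ConfigSketch:
    def __init__(self, tables: dict):
        self.tables = tables

    @classmethod
    def from_toml_path(cls, toml_path) -> "ConfigSketch":
        toml_path = pathlib.Path(toml_path)
        if not toml_path.exists():
            raise FileNotFoundError(f"config file not found: {toml_path}")
        with toml_path.open("r") as fp:
            config_dict = tomlkit.load(fp)
        # return just the 'vak' sub-table, mirroring how _load_toml_from_path
        # saves calling functions from writing ['vak'] everywhere
        return cls(tables=dict(config_dict["vak"]))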
cfg.train.model.asdict() config["network"].update( encoder=dict(input_shape=input_shape) ) diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 6726ec09b..5b8902f9d 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -37,32 +37,27 @@ def test_predict_with_frame_classification_model( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": device}, - {"section": "PREDICT", "option": "save_net_outputs", "value": save_net_outputs}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": device}, + {"table": "predict", "key": "save_net_outputs", "value": save_net_outputs}, ] toml_path = specific_config_toml_path( config_type="predict", model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) + cfg = vak.config.Config.from_toml_path(toml_path) vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model, - model_config=model_config, - dataset_path=cfg.predict.dataset_path, + model_config=cfg.predict.model.asdict(), + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, @@ -78,8 +73,8 @@ def test_predict_with_frame_classification_model( Path(output_dir).glob(f"*{vak.common.constants.NET_OUTPUT_SUFFIX}") ) - metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict.dataset_path) - dataset_csv_path = cfg.predict.dataset_path / metadata.dataset_csv_filename + metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict.dataset.path) + dataset_csv_path = cfg.predict.dataset.path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) for spect_path in dataset_df.spect_path.values: @@ -94,9 +89,9 @@ def test_predict_with_frame_classification_model( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "PREDICT", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, - {"section": "PREDICT", "option": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, - {"section": "PREDICT", "option": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, + {"table": "predict", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "predict", "key": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, + {"table": "predict", "key": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, ] ) def test_predict_with_frame_classification_model_raises_file_not_found( @@ -112,9 +107,9 @@ def test_predict_with_frame_classification_model_raises_file_not_found( ) output_dir.mkdir() - options_to_change = [ - {"section": 
"PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": device}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": device}, path_option_to_change, ] toml_path = specific_config_toml_path( @@ -122,23 +117,18 @@ def test_predict_with_frame_classification_model_raises_file_not_found( model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) + cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(FileNotFoundError): vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model, - model_config=model_config, - dataset_path=cfg.predict.dataset_path, + model_config=cfg.predict.model.asdict(), + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, @@ -152,8 +142,8 @@ def test_predict_with_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "PREDICT", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "PREDICT", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "predict", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "predict", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) def test_predict_with_frame_classification_model_raises_not_a_directory( @@ -165,20 +155,20 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( """Test that core.eval raises NotADirectory when ``output_dir`` does not exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "PREDICT", "option": "device", "value": device}, + {"table": "predict", "key": "device", "value": device}, ] - if path_option_to_change["option"] != "output_dir": + if path_option_to_change["key"] != "output_dir": # need to make sure output_dir *does* exist # so we don't detect spurious NotADirectoryError and assume test passes output_dir = tmp_path.joinpath( f"test_predict_raises_not_a_directory" ) output_dir.mkdir() - options_to_change.append( - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)} + keys_to_change.append( + {"table": "predict", "key": "output_dir", "value": str(output_dir)} ) toml_path = specific_config_toml_path( @@ -186,22 +176,18 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) + cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(NotADirectoryError): 
vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model, - model_config=model_config, - dataset_path=cfg.predict.dataset_path, + model_config=cfg.predict.model.asdict(), + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 98051ca80..820613ed4 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -26,9 +26,9 @@ def test_predict( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": 'cpu'}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": 'cpu'}, ] toml_path = specific_config_toml_path( @@ -36,25 +36,21 @@ def test_predict( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) + cfg = vak.config.Config.from_toml_path(toml_path) results_path = tmp_path / 'results_path' results_path.mkdir() with mock.patch(predict_function_to_mock, autospec=True) as mock_predict_function: vak.predict.predict( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.predict.dataset_path, + model_config=cfg.predict.model.asdict(), + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, diff --git a/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py b/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py index d354dfd6a..87b9489eb 100644 --- a/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py +++ b/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py @@ -27,7 +27,7 @@ def test_assign_samples_to_splits( spect_format, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # ---- set up ---- tmp_dataset_path = tmp_path / 'dataset_dir' diff --git a/tests/test_prep/test_frame_classification/test_frame_classification.py b/tests/test_prep/test_frame_classification/test_frame_classification.py index 31a264847..968f4d19b 100644 --- a/tests/test_prep/test_frame_classification/test_frame_classification.py +++ b/tests/test_prep/test_frame_classification/test_frame_classification.py @@ -80,10 +80,10 @@ def test_prep_frame_classification_dataset( ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { 
- "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -93,9 +93,9 @@ def test_prep_frame_classification_dataset( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() dataset_df, dataset_path = vak.prep.frame_classification.frame_classification.prep_frame_classification_dataset( @@ -104,7 +104,7 @@ def test_prep_frame_classification_dataset( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -149,14 +149,14 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ ) output_dir.mkdir() - options_to_change = [ - {"section": "PREP", - "option": "output_dir", + keys_to_change = [ + {"table": "prep", + "key": "output_dir", "value": str(output_dir), }, - {"section": "PREP", - "option": "labelset", - "value": "DELETE-OPTION", + {"table": "prep", + "key": "labelset", + "value": "DELETE-KEY", }, ] toml_path = specific_config_toml_path( @@ -165,9 +165,9 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() with pytest.raises(ValueError): @@ -177,7 +177,7 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -214,15 +214,15 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "PREP", - "option": "data_dir", + "table": "prep", + "key": "data_dir", "value": str(data_dir), }, { - "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -233,9 +233,9 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te audio_format='cbin', annot_format='notmat', spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = 'eval' dataset_df, dataset_path = vak.prep.frame_classification.frame_classification.prep_frame_classification_dataset( @@ -244,7 +244,7 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -272,15 +272,15 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ ) output_dir.mkdir() - options_to_change = [ + keys_to_change = 
[ { - "section": "PREP", - "option": "data_dir", + "table": "prep", + "key": "data_dir", "value": str(data_dir), }, { - "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -291,9 +291,9 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ audio_format='cbin', annot_format='notmat', spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = 'eval' dataset_df, dataset_path = vak.prep.frame_classification.frame_classification.prep_frame_classification_dataset( @@ -302,7 +302,7 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -318,8 +318,8 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ @pytest.mark.parametrize( "dir_option_to_change", [ - {"section": "PREP", "option": "data_dir", "value": '/obviously/does/not/exist/data'}, - {"section": "PREP", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "prep", "key": "data_dir", "value": '/obviously/does/not/exist/data'}, + {"table": "prep", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ], ) def test_prep_frame_classification_dataset_raises_not_a_directory( @@ -338,9 +338,9 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=dir_option_to_change, + keys_to_change=dir_option_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = "train" with pytest.raises(NotADirectoryError): @@ -350,7 +350,7 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -364,7 +364,7 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( @pytest.mark.parametrize( "path_option_to_change", [ - {"section": "PREP", "option": "annot_file", "value": '/obviously/does/not/exist/annot.mat'}, + {"table": "prep", "key": "annot_file", "value": '/obviously/does/not/exist/annot.mat'}, ], ) def test_prep_frame_classification_dataset_raises_file_not_found( @@ -386,9 +386,9 @@ def test_prep_frame_classification_dataset_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=path_option_to_change, + keys_to_change=path_option_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = "train" with pytest.raises(FileNotFoundError): @@ -398,7 +398,7 @@ def test_prep_frame_classification_dataset_raises_file_not_found( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, diff --git 
a/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py b/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py index d43ca7380..a49bbf5ba 100644 --- a/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py +++ b/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py @@ -40,7 +40,7 @@ def test_get_or_make_source_files( spect_format, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # ---- set up ---- tmp_dataset_path = tmp_path / 'dataset_dir' @@ -55,7 +55,7 @@ def test_get_or_make_source_files( cfg.prep.input_type, cfg.prep.audio_format, cfg.prep.spect_format, - cfg.spect_params, + cfg.prep.spect_params, tmp_dataset_path, cfg.prep.annot_format, cfg.prep.annot_file, @@ -77,7 +77,7 @@ def test_get_or_make_source_files( cfg.prep.input_type, cfg.prep.audio_format, cfg.prep.spect_format, - cfg.spect_params, + cfg.prep.spect_params, tmp_dataset_path, cfg.prep.annot_format, cfg.prep.annot_file, diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index 150e6483a..ef3105144 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -22,10 +22,10 @@ def test_make_index_vectors_for_each_subsets( ): root_results_dir = tmp_path.joinpath("tmp_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, ] @@ -34,11 +34,11 @@ def test_make_index_vectors_for_each_subsets( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve.dataset_path + dataset_path = cfg.learncurve.dataset.path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) @@ -134,10 +134,10 @@ def test_make_subsets_from_dataset_df( ): root_results_dir = tmp_path.joinpath("tmp_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, ] @@ -146,11 +146,11 @@ def test_make_subsets_from_dataset_df( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve.dataset_path + dataset_path = cfg.learncurve.dataset.path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) diff --git a/tests/test_prep/test_frame_classification/test_make_splits.py b/tests/test_prep/test_frame_classification/test_make_splits.py index 5d5ef11cf..11037a5c3 100644 --- a/tests/test_prep/test_frame_classification/test_make_splits.py +++ b/tests/test_prep/test_frame_classification/test_make_splits.py @@ -88,7 
+88,7 @@ def test_make_splits(config_type, model_name, audio_format, spect_format, annot_ audio_format, spect_format, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # ---- set up ---- tmp_dataset_path = tmp_path / 'dataset_dir' diff --git a/tests/test_prep/test_prep.py b/tests/test_prep/test_prep.py index 8e995f8bc..370cb6647 100644 --- a/tests/test_prep/test_prep.py +++ b/tests/test_prep/test_prep.py @@ -33,10 +33,10 @@ def test_prep( ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -46,9 +46,9 @@ def test_prep( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() # ---- test @@ -61,7 +61,7 @@ def test_prep( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 9cddd7e17..f4e50ef46 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -46,9 +46,9 @@ def test_train_frame_classification_model( ): results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, - {"section": "TRAIN", "option": "root_results_dir", "value": results_path} + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( config_type="train", @@ -56,22 +56,16 @@ def test_train_frame_classification_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + cfg = vak.config.Config.from_toml_path(toml_path) vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset_path, + model_config=cfg.train.model.asdict(), + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, @@ -83,7 +77,7 @@ def test_train_frame_classification_model( device=cfg.train.device, ) - assert_train_output_matches_expected(cfg, cfg.train.model, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path) @pytest.mark.slow @@ -99,9 +93,9 @@ def test_continue_training( ): results_path = 
vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, - {"section": "TRAIN", "option": "root_results_dir", "value": results_path} + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( config_type="train_continue", @@ -109,22 +103,16 @@ def test_continue_training( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + cfg = vak.config.Config.from_toml_path(toml_path) vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset_path, + model_config=cfg.train.model.asdict(), + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, @@ -136,14 +124,14 @@ def test_continue_training( device=cfg.train.device, ) - assert_train_output_matches_expected(cfg, cfg.train.model, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path) @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "TRAIN", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, - {"section": "TRAIN", "option": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, + {"table": "train", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "train", "key": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, ] ) def test_train_raises_file_not_found( @@ -153,8 +141,8 @@ def test_train_raises_file_not_found( when one of the following does not exist: checkpoint_path, dataset_path, spect_scaler_path """ - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, path_option_to_change ] toml_path = specific_config_toml_path( @@ -163,25 +151,19 @@ def test_train_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + cfg = vak.config.Config.from_toml_path(toml_path) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() with pytest.raises(FileNotFoundError): vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset_path, + model_config=cfg.train.model.asdict(), + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - 
train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, @@ -197,8 +179,8 @@ def test_train_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "TRAIN", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "TRAIN", "option": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, + {"table": "train", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "train", "key": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, ] ) def test_train_raises_not_a_directory( @@ -207,9 +189,9 @@ def test_train_raises_not_a_directory( """Test that core.train raises NotADirectory when directory does not exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "TRAIN", "option": "device", "value": device}, + {"table": "train", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -218,26 +200,20 @@ def test_train_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + cfg = vak.config.Config.from_toml_path(toml_path) # mock behavior of cli.train, building `results_path` from config option `root_results_dir` results_path = cfg.train.root_results_dir / 'results-dir-timestamp' with pytest.raises(NotADirectoryError): vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset_path, + model_config=cfg.train.model.asdict(), + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index a64516e0a..509078c21 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -39,9 +39,9 @@ def test_train_parametric_umap_model( ): results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, - {"section": "TRAIN", "option": "root_results_dir", "value": results_path} + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( config_type="train", @@ -49,22 +49,16 @@ def test_train_parametric_umap_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) - model_config = 
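
Every train test above now follows the same calling pattern: a single `Config.from_toml_path` call replaces the old `config.parse.from_toml_path` plus `config.model.config_from_toml_path` pair, and the attrs-based sub-configs are converted to plain dicts at the call site. Outside the test suite, the pattern looks roughly like this, assuming a valid version-1.0 config file; the path below is hypothetical:

    import vak

    toml_path = "gy6or6_train.toml"  # hypothetical config file

    # One call loads everything, including the model and dataset sub-tables:
    cfg = vak.config.Config.from_toml_path(toml_path)

    # The sub-configs convert to dicts for the train functions, replacing
    # the old model_name / dataset_path arguments:
    vak.train.frame_classification.train_frame_classification_model(
        model_config=cfg.train.model.asdict(),
        dataset_config=cfg.train.dataset.asdict(),
        batch_size=cfg.train.batch_size,
        num_epochs=cfg.train.num_epochs,
        num_workers=cfg.train.num_workers,
        checkpoint_path=cfg.train.checkpoint_path,
        spect_scaler_path=cfg.train.spect_scaler_path,
        # cli.train builds a timestamped results dir under root_results_dir,
        # mirrored here with a placeholder name as in the tests above:
        results_path=cfg.train.root_results_dir / "results-dir-timestamp",
        device=cfg.train.device,
    )
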
diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py
index a64516e0a..509078c21 100644
--- a/tests/test_train/test_parametric_umap.py
+++ b/tests/test_train/test_parametric_umap.py
@@ -39,9 +39,9 @@ def test_train_parametric_umap_model(
 ):
     results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
     results_path.mkdir()
-    options_to_change = [
-        {"section": "TRAIN", "option": "device", "value": device},
-        {"section": "TRAIN", "option": "root_results_dir", "value": results_path}
+    keys_to_change = [
+        {"table": "train", "key": "device", "value": device},
+        {"table": "train", "key": "root_results_dir", "value": str(results_path)}
     ]
     toml_path = specific_config_toml_path(
         config_type="train",
@@ -49,22 +49,16 @@ def test_train_parametric_umap_model(
         audio_format=audio_format,
         annot_format=annot_format,
         spect_format=spect_format,
-        options_to_change=options_to_change,
+        keys_to_change=keys_to_change,
     )
-    cfg = vak.config.parse.from_toml_path(toml_path)
-    model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model)
+    cfg = vak.config.Config.from_toml_path(toml_path)

     vak.train.parametric_umap.train_parametric_umap_model(
-        model_name=cfg.train.model,
-        model_config=model_config,
-        dataset_path=cfg.train.dataset_path,
+        model_config=cfg.train.model.asdict(),
+        dataset_config=cfg.train.dataset.asdict(),
         batch_size=cfg.train.batch_size,
         num_epochs=cfg.train.num_epochs,
         num_workers=cfg.train.num_workers,
-        train_transform_params=cfg.train.train_transform_params,
-        train_dataset_params=cfg.train.train_dataset_params,
-        val_transform_params=cfg.train.val_transform_params,
-        val_dataset_params=cfg.train.val_dataset_params,
         checkpoint_path=cfg.train.checkpoint_path,
         results_path=results_path,
         shuffle=cfg.train.shuffle,
@@ -73,13 +67,13 @@ def test_train_parametric_umap_model(
         device=cfg.train.device,
     )

-    assert_train_output_matches_expected(cfg, cfg.train.model, results_path)
+    assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path)


 @pytest.mark.parametrize(
     'path_option_to_change',
     [
-        {"section": "TRAIN", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'},
+        {"table": "train", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'},
     ]
 )
 def test_train_parametric_umap_model_raises_file_not_found(
@@ -89,8 +83,8 @@ def test_train_parametric_umap_model_raises_file_not_found(
     raise FileNotFoundError when one of the following does not exist:
     checkpoint_path, dataset_path
     """
-    options_to_change = [
-        {"section": "TRAIN", "option": "device", "value": device},
+    keys_to_change = [
+        {"table": "train", "key": "device", "value": device},
         path_option_to_change
     ]
     toml_path = specific_config_toml_path(
@@ -99,25 +93,19 @@ def test_train_parametric_umap_model_raises_file_not_found(
         audio_format="cbin",
         annot_format="notmat",
         spect_format=None,
-        options_to_change=options_to_change,
+        keys_to_change=keys_to_change,
     )
-    cfg = vak.config.parse.from_toml_path(toml_path)
-    model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model)
+    cfg = vak.config.Config.from_toml_path(toml_path)
     results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path)
     results_path.mkdir()

     with pytest.raises(FileNotFoundError):
         vak.train.parametric_umap.train_parametric_umap_model(
-            model_name=cfg.train.model,
-            model_config=model_config,
-            dataset_path=cfg.train.dataset_path,
+            model_config=cfg.train.model.asdict(),
+            dataset_config=cfg.train.dataset.asdict(),
             batch_size=cfg.train.batch_size,
             num_epochs=cfg.train.num_epochs,
             num_workers=cfg.train.num_workers,
-            train_transform_params=cfg.train.train_transform_params,
-            train_dataset_params=cfg.train.train_dataset_params,
-            val_transform_params=cfg.train.val_transform_params,
-            val_dataset_params=cfg.train.val_dataset_params,
             checkpoint_path=cfg.train.checkpoint_path,
             results_path=results_path,
             shuffle=cfg.train.shuffle,
@@ -130,8 +118,8 @@ def test_train_parametric_umap_model_raises_file_not_found(
 @pytest.mark.parametrize(
     'path_option_to_change',
     [
-        {"section": "TRAIN", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'},
-        {"section": "TRAIN", "option": "root_results_dir", "value": '/obviously/doesnt/exist/results/'},
+        {"table": "train", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'},
+        {"table": "train", "key": "root_results_dir", "value": '/obviously/doesnt/exist/results/'},
     ]
 )
 def test_train_parametric_umap_model_raises_not_a_directory(
@@ -140,9 +128,9 @@ def test_train_parametric_umap_model_raises_not_a_directory(
     """Test that core.train raises NotADirectory
     when directory does not exist
     """
-    options_to_change = [
+    keys_to_change = [
         path_option_to_change,
-        {"section": "TRAIN", "option": "device", "value": device},
+        {"table": "train", "key": "device", "value": device},
     ]

     toml_path = specific_config_toml_path(
@@ -151,26 +139,21 @@ def test_train_parametric_umap_model_raises_not_a_directory(
         audio_format="cbin",
         annot_format="notmat",
         spect_format=None,
-        options_to_change=options_to_change,
+        keys_to_change=keys_to_change,
     )
-    cfg = vak.config.parse.from_toml_path(toml_path)
-    model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model)
+    cfg = vak.config.Config.from_toml_path(toml_path)
+    model_config = cfg.train.model.asdict()

     # mock behavior of cli.train, building `results_path` from config option `root_results_dir`
     results_path = cfg.train.root_results_dir / 'results-dir-timestamp'

     with pytest.raises(NotADirectoryError):
         vak.train.parametric_umap.train_parametric_umap_model(
-            model_name=cfg.train.model,
             model_config=model_config,
-            dataset_path=cfg.train.dataset_path,
+            dataset_config=cfg.train.dataset.asdict(),
             batch_size=cfg.train.batch_size,
             num_epochs=cfg.train.num_epochs,
             num_workers=cfg.train.num_workers,
-            train_transform_params=cfg.train.train_transform_params,
-            train_dataset_params=cfg.train.train_dataset_params,
-            val_transform_params=cfg.train.val_transform_params,
-            val_dataset_params=cfg.train.val_dataset_params,
             checkpoint_path=cfg.train.checkpoint_path,
             results_path=results_path,
             shuffle=cfg.train.shuffle,
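
Note the list-valued `"key": ["dataset", "path"]` in the `raises_not_a_directory` parametrizations above: because the dataset path is now nested under the `dataset` sub-table, the test helper has to walk down into the parsed document instead of setting a flat `dataset_path` key. Presumably it does something along these lines; this is a sketch only, not the actual `specific_config_toml_path` implementation:

    import tomlkit

    def set_config_key(toml_text, table, key, value):
        """Sketch: set `key` in `[vak.<table>]`, descending when `key` is a list."""
        doc = tomlkit.parse(toml_text)
        target = doc["vak"][table]
        if isinstance(key, list):
            *parents, last = key
            for parent in parents:
                target = target[parent]  # e.g. descend into the `dataset` sub-table
            target[last] = value
        else:
            target[key] = value
        return tomlkit.dumps(doc)

So `set_config_key(text, "train", ["dataset", "path"], "/obviously/doesnt/exist/dataset-dir")` would rewrite `[vak.train.dataset]`'s `path` key, which is what these parametrizations rely on.
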
diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py
index 559853a24..b9038007e 100644
--- a/tests/test_train/test_train.py
+++ b/tests/test_train/test_train.py
@@ -29,13 +29,13 @@ def test_train(
     root_results_dir = tmp_path.joinpath("test_train_root_results_dir")
     root_results_dir.mkdir()

-    options_to_change = [
+    keys_to_change = [
         {
-            "section": "TRAIN",
-            "option": "root_results_dir",
+            "table": "train",
+            "key": "root_results_dir",
             "value": str(root_results_dir),
         },
-        {"section": "TRAIN", "option": "device", "value": 'cpu'},
+        {"table": "train", "key": "device", "value": 'cpu'},
     ]

     toml_path = specific_config_toml_path(
@@ -44,26 +44,20 @@ def test_train(
         audio_format=audio_format,
         annot_format=annot_format,
         spect_format=spect_format,
-        options_to_change=options_to_change,
+        keys_to_change=keys_to_change,
     )
-    cfg = vak.config.parse.from_toml_path(toml_path)
-    model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model)
+    cfg = vak.config.Config.from_toml_path(toml_path)

     results_path = tmp_path / 'results_path'
     results_path.mkdir()

     with mock.patch(train_function_to_mock, autospec=True) as mock_train_function:
         vak.train.train(
-            model_name=model_name,
-            model_config=model_config,
-            dataset_path=cfg.train.dataset_path,
+            model_config=cfg.train.model.asdict(),
+            dataset_config=cfg.train.dataset.asdict(),
             batch_size=cfg.train.batch_size,
             num_epochs=cfg.train.num_epochs,
             num_workers=cfg.train.num_workers,
-            train_transform_params=cfg.train.train_transform_params,
-            train_dataset_params=cfg.train.train_dataset_params,
-            val_transform_params=cfg.train.val_transform_params,
-            val_dataset_params=cfg.train.val_dataset_params,
             checkpoint_path=cfg.train.checkpoint_path,
             spect_scaler_path=cfg.train.spect_scaler_path,
             results_path=results_path,