From 5ebe4074a8a0fc9e7bcb86bd1f2eb582847f045a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 9 Sep 2024 15:11:26 -0400 Subject: [PATCH] ENH: fix frame classification model to work with BioSoundSegBench (#774) This PR consists mainly of fixes needed for FrameClassificationModel to work with the BioSoundSegBench dataset. * Hack a learning rate scheduler into FrameClassificationModel * Add boundary_labels parameter to transforms.frame_labels.transforms.PostProcess.__call__ * Add background_label to post_tfm_kwargs in src/vak/eval/frame_classification.py * In `validation_step` of FrameClassificationModel, use boundary labels when they are present to post-process multi-class frame labels * Unpack `dataset_path` from dataset_config in code block for built-in datasets in eval/frame_classification.py, to make sure this variable exists when we build the DataFrame with eval results * Make variable `frame_dur` inside code block for built-in datasets inside eval/frame_classification.py so this variable exists when we get the post-processing transform * Pass `background_label` into transforms.frame_labels.PostProcess inside eval/frame_classification.py, using `constants.DEFAULT_BACKGROUND_LABEL` * Fix how we call self.manual_backward in FrameClassificationModel to handle the case when the loss function returns a dict * In FrameClassificationModel.validation_step, convert boundary_preds to numpy when we pass them in to self.post_tfm * In FrameClassificationModel.validation_step, when logging accuracy, call it 'val_multi_acc' to distinguish from boundary_acc and for consistency with val_multi_acc_tfm * Change how we get and log frame_dur in train/frame_classification.py so we have it as a separate variable; will use for post_tfm kwargs when we add those later * Change one-line summary of __call__ method for frame_labels.transforms.PostProcess * BUG: Ensure boundary_labels is 1d in post-process transform, fix #767 * Fix what metric we use for learning rate scheduler:
use val_multi_acc for models with multiple accuracies * Remove trainer module from common, code is used only for frame classification model * Add get_trainer and get_callbacks to train/frame_classification.py, fix so that we monitor 'val_multi_acc' when a model has multiple targets, and just 'val_acc' otherwise * Add missing self.manual_backward in training_step of FrameClassificationModel * Fix how we determine whether there are multiple targets and what to monitor in train/frame_classification.py * Fix how we validate boundary_labels in transforms.frame_labels.functional.postprocess -- don't if boundary_labels is None * Fix vak/predict/frame_classification.py to handle edge case where no non-background segments are predicted for any sample in dataset * Revise comment * Catch edge case in transforms.frame_labels.functional.boundary_inds_from_frame_boundary_labels * Add minimal unit tests for vak.transforms.frame_labels.functional.boundary_inds_from_boundary_labels * Remove learning rate scheduler for now --- src/vak/common/__init__.py | 2 - src/vak/common/trainer.py | 88 ------------- src/vak/eval/frame_classification.py | 11 +- src/vak/models/frame_classification_model.py | 19 +-- src/vak/predict/frame_classification.py | 7 +- src/vak/train/frame_classification.py | 117 ++++++++++++++++-- src/vak/transforms/frame_labels/functional.py | 14 ++- src/vak/transforms/frame_labels/transforms.py | 7 +- .../test_frame_labels/test_functional.py | 39 ++++++ 9 files changed, 190 insertions(+), 114 deletions(-) delete mode 100644 src/vak/common/trainer.py diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py index c5be9ccfd..84e1190b3 100644 --- a/src/vak/common/__init__.py +++ b/src/vak/common/__init__.py @@ -21,7 +21,6 @@ tensorboard, timebins, timenow, - trainer, typing, validators, ) @@ -39,7 +38,6 @@ "tensorboard", "timebins", "timenow", - "trainer", "typing", "validators", ] diff --git a/src/vak/common/trainer.py b/src/vak/common/trainer.py deleted file 
mode 100644 index cd12d03d4..000000000 --- a/src/vak/common/trainer.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import pathlib - -import lightning - - -def get_default_train_callbacks( - ckpt_root: str | pathlib.Path, - ckpt_step: int, - patience: int, -): - ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( - dirpath=ckpt_root, - filename="checkpoint", - every_n_train_steps=ckpt_step, - save_last=True, - verbose=True, - ) - ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" - ckpt_callback.FILE_EXTENSION = ".pt" - - val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( - monitor="val_acc", - dirpath=ckpt_root, - save_top_k=1, - mode="max", - filename="max-val-acc-checkpoint", - auto_insert_metric_name=False, - verbose=True, - ) - val_ckpt_callback.FILE_EXTENSION = ".pt" - - early_stopping = lightning.pytorch.callbacks.EarlyStopping( - mode="max", - monitor="val_acc", - patience=patience, - verbose=True, - ) - - return [ckpt_callback, val_ckpt_callback, early_stopping] - - -def get_default_trainer( - accelerator: str, - devices: int | list[int], - max_steps: int, - log_save_dir: str | pathlib.Path, - val_step: int, - default_callback_kwargs: dict | None = None, -) -> lightning.pytorch.Trainer: - """Returns an instance of :class:`lightning.pytorch.Trainer` - with a default set of callbacks. - - Used by :func:`vak.train.frame_classification`. - The default set of callbacks is provided by - :func:`get_default_train_callbacks`. 
- - Parameters - ---------- - accelerator : str - devices : int, list of int - max_steps : int - log_save_dir : str, pathlib.Path - val_step : int - default_callback_kwargs : dict, optional - - Returns - ------- - trainer : lightning.pytorch.Trainer - - """ - if default_callback_kwargs: - callbacks = get_default_train_callbacks(**default_callback_kwargs) - else: - callbacks = None - - logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir) - - trainer = lightning.pytorch.Trainer( - accelerator=accelerator, - devices=devices, - callbacks=callbacks, - val_check_interval=val_step, - max_steps=max_steps, - logger=logger, - ) - return trainer diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 12688288f..665a0bb4f 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -14,7 +14,7 @@ import torch.utils.data from .. import datapipes, datasets, models, transforms -from ..common import validators +from ..common import constants, validators from ..datapipes.frame_classification import InferDatapipe logger = logging.getLogger(__name__) @@ -154,12 +154,20 @@ def eval_frame_classification_model( ) # ---- *yes* using a built-in dataset ------------------------------------------------------------------------------ else: + # next line, we don't use dataset path in this code block, + # but we need it below when we build the DataFrame with eval results. 
+ # we're unpacking it here just as we do above with a prep'd dataset + dataset_path = pathlib.Path(dataset_config["path"]) dataset_config["params"]["return_padding_mask"] = True val_dataset = datasets.get( dataset_config, split=split, frames_standardizer=frames_standardizer, ) + frame_dur = val_dataset.frame_dur + logger.info( + f"Duration of a frame in dataset, in seconds: {frame_dur}", + ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, @@ -179,6 +187,7 @@ def eval_frame_classification_model( if post_tfm_kwargs: post_tfm = transforms.frame_labels.PostProcess( timebin_dur=frame_dur, + background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL], **post_tfm_kwargs, ) else: diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 2da299cf3..2159c014d 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -130,7 +130,6 @@ def __init__( :const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`. 
""" super().__init__() - self.network = network self.loss = loss self.optimizer = optimizer @@ -365,9 +364,15 @@ def validation_step(self, batch: tuple, batch_idx: int): class_preds_str = self.to_labels_eval(class_preds.cpu().numpy()) if self.post_tfm: - class_preds_tfm = self.post_tfm( - class_preds.cpu().numpy(), - ) + if target_types == ("multi_frame_labels",): + class_preds_tfm = self.post_tfm( + class_preds.cpu().numpy(), + ) + elif target_types == ("multi_frame_labels", "boundary_frame_labels"): + class_preds_tfm = self.post_tfm( + class_preds.cpu().numpy(), + boundary_labels=boundary_preds.cpu().numpy(), + ) class_preds_tfm_str = self.to_labels_eval(class_preds_tfm) # convert back to tensor so we can compute accuracy class_preds_tfm = torch.from_numpy(class_preds_tfm).to( @@ -395,8 +400,8 @@ def validation_step(self, batch: tuple, batch_idx: int): loss = self.loss( class_logits, boundary_logits, - batch["multi_frame_labels"], - batch["boundary_frame_labels"], + target["multi_frame_labels"], + target["boundary_frame_labels"], ) if isinstance(loss, torch.Tensor): self.log( @@ -435,7 +440,7 @@ def validation_step(self, batch: tuple, batch_idx: int): ) else: self.log( - f"val_{metric_name}", + f"val_multi_{metric_name}", metric_callable( class_preds, target["multi_frame_labels"] ), diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index eff70ef48..148de01ff 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -468,8 +468,11 @@ def predict_with_frame_classification_model( annot_path=annot_csv_path.name, ) annots.append(annot) - - if all([isinstance(annot, crowsetta.Annotation) for annot in annots]): + if len(annots) < 1: + # catch edge case where nothing was predicted + # FIXME: this should have columns that match GenericSeq + pd.DataFrame.from_records([]).to_csv(annot_csv_path) + elif all([isinstance(annot, crowsetta.Annotation) for annot in annots]): generic_seq = 
crowsetta.formats.seq.GenericSeq(annots=annots) generic_seq.to_file(annot_path=annot_csv_path) elif all([isinstance(annot, AnnotationDataFrame) for annot in annots]): diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index f83f6518f..74f6a269e 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -8,13 +8,13 @@ import pathlib import shutil +import lightning import joblib import pandas as pd import torch.utils.data from .. import datapipes, datasets, models, transforms from ..common import validators -from ..common.trainer import get_default_trainer from ..datapipes.frame_classification import InferDatapipe, TrainDatapipe logger = logging.getLogger(__name__) @@ -25,6 +25,92 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: return df[df["split"] == split]["duration"].sum() +def get_train_callbacks( + ckpt_root: str | pathlib.Path, + ckpt_step: int, + patience: int, + checkpoint_monitor: str = "val_acc", + early_stopping_monitor: str = "val_acc", + early_stopping_mode: str = "max", +) -> list[lightning.pytorch.callbacks.Callback]: + ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( + dirpath=ckpt_root, + filename="checkpoint", + every_n_train_steps=ckpt_step, + save_last=True, + verbose=True, + ) + ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" + ckpt_callback.FILE_EXTENSION = ".pt" + + val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint( + monitor=checkpoint_monitor, + dirpath=ckpt_root, + save_top_k=1, + mode="max", + filename="max-val-acc-checkpoint", + auto_insert_metric_name=False, + verbose=True, + ) + val_ckpt_callback.FILE_EXTENSION = ".pt" + + early_stopping = lightning.pytorch.callbacks.EarlyStopping( + mode=early_stopping_mode, + monitor=early_stopping_monitor, + patience=patience, + verbose=True, + ) + + return [ckpt_callback, val_ckpt_callback, early_stopping] + + +def get_trainer( + accelerator: str, + devices: int | list[int], + 
max_steps: int, + log_save_dir: str | pathlib.Path, + val_step: int, + callback_kwargs: dict | None = None, +) -> lightning.pytorch.Trainer: + """Returns an instance of :class:`lightning.pytorch.Trainer` + with a default set of callbacks. + + Used by :func:`vak.train.frame_classification`. + The default set of callbacks is provided by + :func:`get_default_train_callbacks`. + + Parameters + ---------- + accelerator : str + devices : int, list of int + max_steps : int + log_save_dir : str, pathlib.Path + val_step : int + default_callback_kwargs : dict, optional + + Returns + ------- + trainer : lightning.pytorch.Trainer + + """ + if callback_kwargs: + callbacks = get_train_callbacks(**callback_kwargs) + else: + callbacks = None + + logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir) + + trainer = lightning.pytorch.Trainer( + accelerator=accelerator, + devices=devices, + callbacks=callbacks, + val_check_interval=val_step, + max_steps=max_steps, + logger=logger, + ) + return trainer + + def train_frame_classification_model( model_config: dict, dataset_config: dict, @@ -245,8 +331,9 @@ def train_frame_classification_model( dataset_config, split="train", ) + frame_dur = train_dataset.frame_dur logger.info( - f"Duration of a frame in dataset, in seconds: {train_dataset.frame_dur}", + f"Duration of a frame in dataset, in seconds: {frame_dur}", ) # copy labelmap from dataset to new results_path labelmap = train_dataset.labelmap @@ -334,18 +421,30 @@ def train_frame_classification_model( ckpt_root.mkdir() logger.info(f"training {model_name}") max_steps = num_epochs * len(train_loader) - default_callback_kwargs = { - "ckpt_root": ckpt_root, - "ckpt_step": ckpt_step, - "patience": patience, - } - trainer = get_default_trainer( + if isinstance(dataset_config["params"]["target_type"], list) and all([isinstance(target_type, str) for target_type in dataset_config["params"]["target_type"]]): + multiple_targets = True + elif 
isinstance(dataset_config["params"]["target_type"], str): + multiple_targets = False + else: + raise ValueError( + f'Invalid value for dataset_config["params"]["target_type"]: {dataset_config["params"]["target_type"], list}' + ) + + callback_kwargs = dict( + ckpt_root=ckpt_root, + ckpt_step=ckpt_step, + patience=patience, + checkpoint_monitor="val_multi_acc" if multiple_targets else "val_acc", + early_stopping_monitor="val_multi_acc" if multiple_targets else "val_acc", + early_stopping_mode="max", + ) + trainer = get_trainer( accelerator=trainer_config["accelerator"], devices=trainer_config["devices"], max_steps=max_steps, log_save_dir=results_model_root, val_step=val_step, - default_callback_kwargs=default_callback_kwargs, + callback_kwargs=callback_kwargs, ) train_time_start = datetime.datetime.now() logger.info(f"Training start time: {train_time_start.isoformat()}") diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index 09e9eaf2e..580d3229a 100644 --- a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -401,11 +401,17 @@ def boundary_inds_from_boundary_labels( If ``True``, and the first index of ``boundary_labels`` is not classified as a boundary, force it to be a boundary. """ + boundary_labels = row_or_1d(boundary_labels) boundary_inds = np.nonzero(boundary_labels)[0] - if boundary_inds[0] != 0 and force_boundary_first_ind: - # force there to be a boundary at index 0 - np.insert(boundary_inds, 0, 0) + if force_boundary_first_ind: + if len(boundary_inds) == 0: + # handle edge case where no boundaries were predicted + boundary_inds = np.array([0]) # replace with a single boundary, at index 0 + else: + if boundary_inds[0] != 0: + # force there to be a boundary at index 0 + np.insert(boundary_inds, 0, 0) return boundary_inds @@ -531,6 +537,8 @@ def postprocess( Vector of frame labels after post-processing is applied. 
""" frame_labels = row_or_1d(frame_labels) + if boundary_labels is not None: + boundary_labels = row_or_1d(boundary_labels) # handle the case when all time bins are predicted to be unlabeled # see https://github.com/NickleDave/vak/issues/383 diff --git a/src/vak/transforms/frame_labels/transforms.py b/src/vak/transforms/frame_labels/transforms.py index 0dab0c504..024a767ba 100644 --- a/src/vak/transforms/frame_labels/transforms.py +++ b/src/vak/transforms/frame_labels/transforms.py @@ -24,6 +24,7 @@ from __future__ import annotations import numpy as np +import numpy.typing as npt from . import functional as F @@ -258,8 +259,9 @@ def __init__( self.min_segment_dur = min_segment_dur self.majority_vote = majority_vote - def __call__(self, frame_labels: np.ndarray) -> np.ndarray: - """Convert vector of frame labels into labels. + def __call__(self, frame_labels: np.ndarray, boundary_labels: npt.NDArray | None = None) -> np.ndarray: + """Apply post-processing transformations + to a vector of frame labels. 
Parameters ---------- @@ -280,4 +282,5 @@ def __call__(self, frame_labels: np.ndarray) -> np.ndarray: self.background_label, self.min_segment_dur, self.majority_vote, + boundary_labels, ) diff --git a/tests/test_transforms/test_frame_labels/test_functional.py b/tests/test_transforms/test_frame_labels/test_functional.py index 9e9013eb6..eed67504f 100644 --- a/tests/test_transforms/test_frame_labels/test_functional.py +++ b/tests/test_transforms/test_frame_labels/test_functional.py @@ -246,6 +246,45 @@ def test_to_segments_real_data( assert np.all(np.abs(annot.seq.offsets_s - offsets_s) < MAX_ABS_DIFF) +@pytest.mark.parametrize( + "boundary_labels, boundary_inds_expected", + [ + ( + np.array([1,0,0,0,1,0,0]), + np.array([0,4]) + ), + ] +) +def test_boundary_inds_from_boundary_labels(boundary_labels, boundary_inds_expected): + boundary_inds = vak.transforms.frame_labels.boundary_inds_from_boundary_labels( + boundary_labels + ) + assert np.array_equal(boundary_inds, boundary_inds_expected) + + +@pytest.mark.parametrize( + "boundary_labels, expected_exception", + [ + # 3-d array should raise a ValueError, needs to be row or 1-d + ( + np.array([[[1,0,0,0,1,0,0]]]), + ValueError + ), + # column vector should raise a ValueError, needs to be row or 1-d + ( + np.array([[1],[0],[0]]), + ValueError + ) + + ] +) +def test_boundary_inds_from_boundary_labels(boundary_labels, expected_exception): + with pytest.raises(expected_exception): + vak.transforms.frame_labels.functional.segment_inds_list_from_boundary_labels( + boundary_labels + ) + + @pytest.mark.parametrize( "frame_labels, seg_inds_list_expected", [