ENH: fix frame classification model to work with BioSoundSegBench (#774)
This PR consists mainly of fixes needed for FrameClassificationModel to work with the BioSoundSegBench dataset.

* Hack a learning rate scheduler into FrameClassificationModel

* Add boundary_labels parameter to transforms.frame_labels.transforms.PostProcess.__call__

* Add background_label to post_tfm_kwargs in src/vak/eval/frame_classification.py

* In `validation_step` of FrameClassificationModel, use boundary labels when they are present to post-process multi-class frame labels (see the sketch after this list)

* Unpack `dataset_path` from dataset_config in the code block for built-in datasets in eval/frame_classification.py, to make sure this variable exists when we build the DataFrame with eval results

* Make the variable `frame_dur` inside the code block for built-in datasets in eval/frame_classification.py so this variable exists when we get the post-processing transform

* Pass `background_label` into transforms.frame_labels.PostProcess inside eval/frame_classification.py, using `constants.DEFAULT_BACKGROUND_LABEL`

* Fix how we call self.manual_backward in FrameClassificationModel to handle the case when the loss function returns a dict

* In FrameClassificationModel.validation_step, convert boundary_preds to numpy when we pass them in to self.post_tfm

* In FrameClassificationModel.validation_step, when logging accuracy, call it 'val_multi_acc' to distinguish from boundary_acc and for consistency with val_multi_acc_tfm

* Change how we get and log frame_dur in train/frame_classification.py so we have it as a separate variable; we will use it for post_tfm kwargs when we add those later

* Change one-line summary of __call__ method for frame_labels.transforms.PostProcess

* BUG: Ensure boundary_labels is 1d in post-process transform, fix #767

* Fix what metric we use for learning rate scheduler: use val_multi_acc for models with multiple accuracies

* Remove trainer module from common; the code is used only by the frame classification model

* Add get_trainer and get_train_callbacks to train/frame_classification.py; fix so that we monitor 'val_multi_acc' when a model has multiple targets, and just 'val_acc' otherwise

* Add missing self.manual_backward in training_step of FrameClassificationModel

* Fix how we determine whether there are multiple targets and what to monitor in train/frame_classification.py

* Fix how we validate boundary_labels in transforms.frame_labels.functional.postprocess -- skip validation when boundary_labels is None

* Fix vak/predict/frame_classification.py to handle edge case where no non-background segments are predicted for any sample in dataset

* Revise comment

* Catch edge case in transforms.frame_labels.functional.boundary_inds_from_boundary_labels

* Add minimal unit tests for vak.transforms.frame_labels.functional.boundary_inds_from_boundary_labels

* Remove learning rate scheduler for now
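
To make the boundary-related changes above concrete, here is a minimal sketch of the idea (not vak's implementation), assuming 1-d numpy arrays of per-frame class labels and binary boundary labels: split the frame labels at predicted boundary frames, then take a majority vote within each segment.

```python
import numpy as np


def majority_vote_with_boundaries(
    frame_labels: np.ndarray, boundary_labels: np.ndarray
) -> np.ndarray:
    """Sketch: post-process per-frame class labels using binary boundary labels."""
    boundary_inds = np.nonzero(boundary_labels)[0]
    if boundary_inds.size == 0 or boundary_inds[0] != 0:
        boundary_inds = np.insert(boundary_inds, 0, 0)  # force a boundary at index 0
    edges = np.append(boundary_inds, frame_labels.shape[-1])
    out = frame_labels.copy()
    for start, stop in zip(edges[:-1], edges[1:]):
        values, counts = np.unique(frame_labels[start:stop], return_counts=True)
        out[start:stop] = values[np.argmax(counts)]  # majority vote within each segment
    return out


frame_labels = np.array([0, 1, 1, 2, 1, 1, 2, 2, 2])
boundary_labels = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0])
print(majority_vote_with_boundaries(frame_labels, boundary_labels))  # [1 1 1 1 2 2 2 2 2]
```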
NickleDave authored Sep 9, 2024
1 parent f4afc5c commit 5ebe407
Showing 9 changed files with 190 additions and 114 deletions.
2 changes: 0 additions & 2 deletions src/vak/common/__init__.py
@@ -21,7 +21,6 @@
tensorboard,
timebins,
timenow,
trainer,
typing,
validators,
)
@@ -39,7 +38,6 @@
"tensorboard",
"timebins",
"timenow",
"trainer",
"typing",
"validators",
]
88 changes: 0 additions & 88 deletions src/vak/common/trainer.py

This file was deleted.

11 changes: 10 additions & 1 deletion src/vak/eval/frame_classification.py
@@ -14,7 +14,7 @@
import torch.utils.data

from .. import datapipes, datasets, models, transforms
from ..common import validators
from ..common import constants, validators
from ..datapipes.frame_classification import InferDatapipe

logger = logging.getLogger(__name__)
@@ -154,12 +154,20 @@ def eval_frame_classification_model(
)
# ---- *yes* using a built-in dataset ------------------------------------------------------------------------------
else:
# next line, we don't use dataset path in this code block,
# but we need it below when we build the DataFrame with eval results.
# we're unpacking it here just as we do above with a prep'd dataset
dataset_path = pathlib.Path(dataset_config["path"])
dataset_config["params"]["return_padding_mask"] = True
val_dataset = datasets.get(
dataset_config,
split=split,
frames_standardizer=frames_standardizer,
)
frame_dur = val_dataset.frame_dur
logger.info(
f"Duration of a frame in dataset, in seconds: {frame_dur}",
)

val_loader = torch.utils.data.DataLoader(
dataset=val_dataset,
@@ -179,6 +187,7 @@
if post_tfm_kwargs:
post_tfm = transforms.frame_labels.PostProcess(
timebin_dur=frame_dur,
background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL],
**post_tfm_kwargs,
)
else:
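
For context, the post-processing transform built in the block above can be sketched standalone as follows. The labelmap, frame duration, and `post_tfm_kwargs` values here are made-up stand-ins for what normally comes from the dataset and the eval config, and the `majority_vote` / `min_segment_dur` keys are assumptions about typical `post_tfm_kwargs` contents.

```python
from vak import transforms
from vak.common import constants

# toy labelmap; during eval this comes from the dataset
labelmap = {constants.DEFAULT_BACKGROUND_LABEL: 0, "a": 1, "b": 2}
post_tfm_kwargs = {"majority_vote": True, "min_segment_dur": 0.02}  # assumed keys

post_tfm = transforms.frame_labels.PostProcess(
    timebin_dur=0.002,  # stand-in for val_dataset.frame_dur, in seconds
    background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL],
    **post_tfm_kwargs,
)
```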
19 changes: 12 additions & 7 deletions src/vak/models/frame_classification_model.py
@@ -130,7 +130,6 @@ def __init__(
:const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
"""
super().__init__()

self.network = network
self.loss = loss
self.optimizer = optimizer
@@ -365,9 +364,15 @@ def validation_step(self, batch: tuple, batch_idx: int):
class_preds_str = self.to_labels_eval(class_preds.cpu().numpy())

if self.post_tfm:
class_preds_tfm = self.post_tfm(
class_preds.cpu().numpy(),
)
if target_types == ("multi_frame_labels",):
class_preds_tfm = self.post_tfm(
class_preds.cpu().numpy(),
)
elif target_types == ("multi_frame_labels", "boundary_frame_labels"):
class_preds_tfm = self.post_tfm(
class_preds.cpu().numpy(),
boundary_labels=boundary_preds.cpu().numpy(),
)
class_preds_tfm_str = self.to_labels_eval(class_preds_tfm)
# convert back to tensor so we can compute accuracy
class_preds_tfm = torch.from_numpy(class_preds_tfm).to(
@@ -395,8 +400,8 @@ def validation_step(self, batch: tuple, batch_idx: int):
loss = self.loss(
class_logits,
boundary_logits,
batch["multi_frame_labels"],
batch["boundary_frame_labels"],
target["multi_frame_labels"],
target["boundary_frame_labels"],
)
if isinstance(loss, torch.Tensor):
self.log(
@@ -435,7 +440,7 @@ def validation_step(self, batch: tuple, batch_idx: int):
)
else:
self.log(
f"val_{metric_name}",
f"val_multi_{metric_name}",
metric_callable(
class_preds, target["multi_frame_labels"]
),
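
One fix listed above changes how `self.manual_backward` is called when the loss function returns a dict; since `training_step` itself is not shown in this diff, here is a generic sketch of that manual-optimization pattern in Lightning. The batch keys are made up, and this is not FrameClassificationModel's actual code.

```python
import lightning
import torch


class DictLossSketch(lightning.pytorch.LightningModule):
    """Sketch: manual optimization with a loss callable that may return a dict."""

    def __init__(self, network: torch.nn.Module, loss):
        super().__init__()
        self.automatic_optimization = False  # required to call self.manual_backward
        self.network = network
        self.loss = loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.network.parameters())

    def training_step(self, batch: dict, batch_idx: int):
        opt = self.optimizers()
        opt.zero_grad()
        out = self.network(batch["frames"])  # assumed batch key
        loss = self.loss(out, batch["multi_frame_labels"])  # a tensor or a dict of terms
        if isinstance(loss, dict):
            # back-propagate the combined term if present, else the sum of all terms
            loss_for_backward = loss.get("loss", sum(loss.values()))
            for name, value in loss.items():
                self.log(f"train_{name}", value)
        else:
            loss_for_backward = loss
            self.log("train_loss", loss_for_backward)
        self.manual_backward(loss_for_backward)
        opt.step()
        return loss_for_backward
```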
7 changes: 5 additions & 2 deletions src/vak/predict/frame_classification.py
@@ -468,8 +468,11 @@ def predict_with_frame_classification_model(
annot_path=annot_csv_path.name,
)
annots.append(annot)

if all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
if len(annots) < 1:
# catch edge case where nothing was predicted
# FIXME: this should have columns that match GenericSeq
pd.DataFrame.from_records([]).to_csv(annot_csv_path)
elif all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
generic_seq.to_file(annot_path=annot_csv_path)
elif all([isinstance(annot, AnnotationDataFrame) for annot in annots]):
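
Pulled out of context, the dispatch above amounts to this standalone sketch, using only calls that appear in the diff (the AnnotationDataFrame branch is omitted here):

```python
import crowsetta
import pandas as pd


def save_annotations(annots: list, annot_csv_path) -> None:
    """Sketch: save predicted annotations, handling the case where nothing was predicted."""
    if len(annots) < 1:
        # no non-background segments predicted for any sample: write an empty csv
        pd.DataFrame.from_records([]).to_csv(annot_csv_path)
    elif all(isinstance(annot, crowsetta.Annotation) for annot in annots):
        generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
        generic_seq.to_file(annot_path=annot_csv_path)
    else:
        raise TypeError("expected an empty list or a list of crowsetta.Annotation")
```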
117 changes: 108 additions & 9 deletions src/vak/train/frame_classification.py
@@ -8,13 +8,13 @@
import pathlib
import shutil

import lightning
import joblib
import pandas as pd
import torch.utils.data

from .. import datapipes, datasets, models, transforms
from ..common import validators
from ..common.trainer import get_default_trainer
from ..datapipes.frame_classification import InferDatapipe, TrainDatapipe

logger = logging.getLogger(__name__)
@@ -25,6 +25,92 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float:
return df[df["split"] == split]["duration"].sum()


def get_train_callbacks(
ckpt_root: str | pathlib.Path,
ckpt_step: int,
patience: int,
checkpoint_monitor: str = "val_acc",
early_stopping_monitor: str = "val_acc",
early_stopping_mode: str = "max",
) -> list[lightning.pytorch.callbacks.Callback]:
ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
dirpath=ckpt_root,
filename="checkpoint",
every_n_train_steps=ckpt_step,
save_last=True,
verbose=True,
)
ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
ckpt_callback.FILE_EXTENSION = ".pt"

val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
monitor=checkpoint_monitor,
dirpath=ckpt_root,
save_top_k=1,
mode="max",
filename="max-val-acc-checkpoint",
auto_insert_metric_name=False,
verbose=True,
)
val_ckpt_callback.FILE_EXTENSION = ".pt"

early_stopping = lightning.pytorch.callbacks.EarlyStopping(
mode=early_stopping_mode,
monitor=early_stopping_monitor,
patience=patience,
verbose=True,
)

return [ckpt_callback, val_ckpt_callback, early_stopping]


def get_trainer(
accelerator: str,
devices: int | list[int],
max_steps: int,
log_save_dir: str | pathlib.Path,
val_step: int,
callback_kwargs: dict | None = None,
) -> lightning.pytorch.Trainer:
"""Returns an instance of :class:`lightning.pytorch.Trainer`
with a default set of callbacks.
Used by :func:`vak.train.frame_classification`.
The default set of callbacks is provided by
:func:`get_default_train_callbacks`.
Parameters
----------
accelerator : str
devices : int, list of int
max_steps : int
log_save_dir : str, pathlib.Path
val_step : int
default_callback_kwargs : dict, optional
Returns
-------
trainer : lightning.pytorch.Trainer
"""
if callback_kwargs:
callbacks = get_train_callbacks(**callback_kwargs)
else:
callbacks = None

logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)

trainer = lightning.pytorch.Trainer(
accelerator=accelerator,
devices=devices,
callbacks=callbacks,
val_check_interval=val_step,
max_steps=max_steps,
logger=logger,
)
return trainer


def train_frame_classification_model(
model_config: dict,
dataset_config: dict,
@@ -245,8 +331,9 @@ def train_frame_classification_model(
dataset_config,
split="train",
)
frame_dur = train_dataset.frame_dur
logger.info(
f"Duration of a frame in dataset, in seconds: {train_dataset.frame_dur}",
f"Duration of a frame in dataset, in seconds: {frame_dur}",
)
# copy labelmap from dataset to new results_path
labelmap = train_dataset.labelmap
@@ -334,18 +421,30 @@ def train_frame_classification_model(
ckpt_root.mkdir()
logger.info(f"training {model_name}")
max_steps = num_epochs * len(train_loader)
default_callback_kwargs = {
"ckpt_root": ckpt_root,
"ckpt_step": ckpt_step,
"patience": patience,
}
trainer = get_default_trainer(
if isinstance(dataset_config["params"]["target_type"], list) and all([isinstance(target_type, str) for target_type in dataset_config["params"]["target_type"]]):
multiple_targets = True
elif isinstance(dataset_config["params"]["target_type"], str):
multiple_targets = False
else:
raise ValueError(
f'Invalid value for dataset_config["params"]["target_type"]: {dataset_config["params"]["target_type"], list}'
)

callback_kwargs = dict(
ckpt_root=ckpt_root,
ckpt_step=ckpt_step,
patience=patience,
checkpoint_monitor="val_multi_acc" if multiple_targets else "val_acc",
early_stopping_monitor="val_multi_acc" if multiple_targets else "val_acc",
early_stopping_mode="max",
)
trainer = get_trainer(
accelerator=trainer_config["accelerator"],
devices=trainer_config["devices"],
max_steps=max_steps,
log_save_dir=results_model_root,
val_step=val_step,
default_callback_kwargs=default_callback_kwargs,
callback_kwargs=callback_kwargs,
)
train_time_start = datetime.datetime.now()
logger.info(f"Training start time: {train_time_start.isoformat()}")
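
A usage sketch of the helpers added above, wiring `callback_kwargs` through `get_trainer`; all paths, devices, and step values below are placeholders, not defaults taken from vak configs.

```python
from vak.train.frame_classification import get_trainer

callback_kwargs = dict(
    ckpt_root="results/ckpts",  # placeholder path
    ckpt_step=500,
    patience=4,
    checkpoint_monitor="val_multi_acc",  # "val_acc" for single-target models
    early_stopping_monitor="val_multi_acc",
    early_stopping_mode="max",
)
trainer = get_trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=10_000,
    log_save_dir="results/logs",  # placeholder path
    val_step=500,
    callback_kwargs=callback_kwargs,
)
```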
14 changes: 11 additions & 3 deletions src/vak/transforms/frame_labels/functional.py
@@ -401,11 +401,17 @@ def boundary_inds_from_boundary_labels(
If ``True``, and the first index of ``boundary_labels`` is not classified as a boundary,
force it to be a boundary.
"""
boundary_labels = row_or_1d(boundary_labels)
boundary_inds = np.nonzero(boundary_labels)[0]

if boundary_inds[0] != 0 and force_boundary_first_ind:
# force there to be a boundary at index 0
np.insert(boundary_inds, 0, 0)
if force_boundary_first_ind:
if len(boundary_inds) == 0:
# handle edge case where no boundaries were predicted
boundary_inds = np.array([0]) # replace with a single boundary, at index 0
else:
if boundary_inds[0] != 0:
# force there to be a boundary at index 0
np.insert(boundary_inds, 0, 0)

return boundary_inds

@@ -531,6 +537,8 @@ def postprocess(
Vector of frame labels after post-processing is applied.
"""
frame_labels = row_or_1d(frame_labels)
if boundary_labels is not None:
boundary_labels = row_or_1d(boundary_labels)

# handle the case when all time bins are predicted to be unlabeled
# see https://github.com/NickleDave/vak/issues/383
Expand Down