Change default backbone finetuning, auto_lr_find, and fix persistent workers bug (#148)

* Default patience from 3 to 5

* Default for auto_lr_find to False

* Change default unfreeze_backbone_at_epoch to 5

* Fix persistent_workers bug

* Revert changes to templates
r-b-g-b authored Oct 22, 2021
1 parent c41ae7e commit c8a354d
Showing 3 changed files with 21 additions and 18 deletions.
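
The net effect of the new defaults can be checked directly once the change is installed. A minimal sketch, assuming `zamba` is installed and that the sub-config classes construct with no arguments as the diff below suggests:

```python
# Sketch: inspect the new defaults introduced by this commit.
# Assumes zamba is installed; expected values are taken from the diff below.
from zamba.models.config import BackboneFinetuneConfig, EarlyStoppingConfig

backbone = BackboneFinetuneConfig()
early_stopping = EarlyStoppingConfig()

print(backbone.unfreeze_backbone_at_epoch)  # 5 (previously 15)
print(early_stopping.patience)              # 5 (previously 3)
```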
14 changes: 7 additions & 7 deletions docs/docs/configurations.md
@@ -244,16 +244,16 @@ class TrainConfig(ZambaBaseModel)
model_name: zamba.models.config.ModelEnum = <ModelEnum.time_distributed: 'time_distributed'>,
dry_run: Union[bool, int] = False,
batch_size: int = 2,
auto_lr_find: bool = True,
auto_lr_find: bool = False,
backbone_finetune_config: zamba.models.config.BackboneFinetuneConfig =
BackboneFinetuneConfig(unfreeze_backbone_at_epoch=15,
BackboneFinetuneConfig(unfreeze_backbone_at_epoch=5,
backbone_initial_ratio_lr=0.01, multiplier=1,
pre_train_bn=False, train_bn=False, verbose=True),
gpus: int = 0,
num_workers: int = 3,
max_epochs: int = None,
early_stopping_config: zamba.models.config.EarlyStoppingConfig =
EarlyStoppingConfig(monitor='val_macro_f1', patience=3,
EarlyStoppingConfig(monitor='val_macro_f1', patience=5,
verbose=True, mode='max'),
weight_download_region: zamba.models.utils.RegionEnum = 'us',
split_proportions: Dict[str, int] = {'train': 3, 'val': 1, 'holdout': 1},
@@ -299,11 +299,11 @@ The batch size to use for training. Defaults to `2`

#### `auto_lr_find (bool, optional)`

Whether to run a [learning rate finder algorithm](https://arxiv.org/abs/1506.01186) when calling `pytorch_lightning.trainer.tune()` to find the optimal initial learning rate. See the PyTorch Lightning [docs](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#auto-lr-find) for more details. Defaults to `True`
Whether to run a [learning rate finder algorithm](https://arxiv.org/abs/1506.01186) when calling `pytorch_lightning.trainer.tune()` to try to find an optimal initial learning rate. The learning rate finder is not guaranteed to find a good learning rate; depending on the dataset, it can select a learning rate that leads to poor model training. Use with caution. See the PyTorch Lightning [docs](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#auto-lr-find) for more details. Defaults to `False`.
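
For context, `auto_lr_find` maps onto PyTorch Lightning's tuner. A rough sketch of the underlying mechanism (plain PyTorch Lightning, not zamba code; `model` and `data_module` are placeholder LightningModule/LightningDataModule objects, and the module is assumed to expose a learning-rate attribute):

```python
# Rough sketch of the learning rate finder at the PyTorch Lightning level.
# `model` and `data_module` are placeholders, not zamba objects.
import pytorch_lightning as pl

trainer = pl.Trainer(auto_lr_find=True)
trainer.tune(model, datamodule=data_module)  # runs the LR range test and updates the model's lr
trainer.fit(model, datamodule=data_module)   # trains with the suggested learning rate
```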

#### `backbone_finetune_config (zamba.models.config.BackboneFinetuneConfig, optional)`

Set parameters to finetune a backbone model to align with the current learning rate. Derived from Pytorch Lightning's built-in [`BackboneFinetuning`](https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/finetuning.html). The default values are specified in the [`BackboneFinetuneConfig` class](api-reference/models-config.md#zamba.models.config.BackboneFinetuneConfig): `BackboneFinetuneConfig(unfreeze_backbone_at_epoch=15, backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False, train_bn=False, verbose=True)`
Set parameters to finetune a backbone model to align with the current learning rate. Derived from Pytorch Lightning's built-in [`BackboneFinetuning`](https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/finetuning.html). The default values are specified in the [`BackboneFinetuneConfig` class](api-reference/models-config.md#zamba.models.config.BackboneFinetuneConfig): `BackboneFinetuneConfig(unfreeze_backbone_at_epoch=5, backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False, train_bn=False, verbose=True)`
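
For example, to unfreeze the backbone later than the new default of 5 epochs, the config can be overridden. A sketch using the field names documented above:

```python
# Sketch: overriding the backbone finetuning schedule.
from zamba.models.config import BackboneFinetuneConfig

backbone_finetune_config = BackboneFinetuneConfig(
    unfreeze_backbone_at_epoch=10,   # wait longer than the new default of 5
    backbone_initial_ratio_lr=0.01,  # scale the backbone lr down relative to the rest of the model
    multiplier=1,
)
```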

#### `gpus (int, optional)`

@@ -319,7 +319,7 @@ The maximum number of epochs to run during training. Defaults to `None`

#### `early_stopping_config (zamba.models.config.EarlyStoppingConfig, optional)`

Parameters to pass to Pytorch lightning's [`EarlyStopping`](https://github.com/PyTorchLightning/pytorch-lightning/blob/c7451b3ccf742b0e8971332caf2e041ceabd9fe8/pytorch_lightning/callbacks/early_stopping.py#L35) to monitor a metric during model training and stop training when the metric stops improving. The default values are specified in the [`EarlyStoppingConfig` class](api-reference/models-config.md#zamba.models.config.EarlyStoppingConfig): `EarlyStoppingConfig(monitor='val_macro_f1', patience=3, verbose=True, mode='max')`
Parameters to pass to Pytorch lightning's [`EarlyStopping`](https://github.com/PyTorchLightning/pytorch-lightning/blob/c7451b3ccf742b0e8971332caf2e041ceabd9fe8/pytorch_lightning/callbacks/early_stopping.py#L35) to monitor a metric during model training and stop training when the metric stops improving. The default values are specified in the [`EarlyStoppingConfig` class](api-reference/models-config.md#zamba.models.config.EarlyStoppingConfig): `EarlyStoppingConfig(monitor='val_macro_f1', patience=5, verbose=True, mode='max')`
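
For example, to stop on validation loss instead of macro F1, a sketch using the field names documented above:

```python
# Sketch: customizing early stopping to monitor validation loss.
from zamba.models.config import EarlyStoppingConfig

early_stopping_config = EarlyStoppingConfig(
    monitor="val_loss",  # "val_macro_f1" (default) or "val_loss"
    patience=5,
    mode="min",          # lower val_loss is better
)
```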

#### `weight_download_region [us|eu|asia]`

@@ -351,4 +351,4 @@ Whether the species outputted by the model should be all zamba species. If you w

#### `model_cache_dir (Path, optional)`

Cache directory where downloaded model weights will be saved. If None and the MODEL_CACHE_DIR environment variable is not set, will use your default cache directory, which is often an automatic temp directory at `~/.cache/zamba`. Defaults to `None`.
19 changes: 11 additions & 8 deletions zamba/models/config.py
@@ -189,7 +189,7 @@ class BackboneFinetuneConfig(ZambaBaseModel):
Args:
unfreeze_backbone_at_epoch (int, optional): Epoch at which the backbone
will be unfrozen. Defaults to 15.
will be unfrozen. Defaults to 5.
backbone_initial_ratio_lr (float, optional): Used to scale down the backbone
learning rate compared to rest of model. Defaults to 0.01.
multiplier (int or float, optional): Multiply the learning rate by a constant
@@ -202,7 +202,7 @@ class BackboneFinetuneConfig(ZambaBaseModel):
Defaults to True.
"""

unfreeze_backbone_at_epoch: Optional[int] = 15
unfreeze_backbone_at_epoch: Optional[int] = 5
backbone_initial_ratio_lr: Optional[float] = 0.01
multiplier: Optional[Union[int, float]] = 1
pre_train_bn: Optional[bool] = False # freeze batch norm layers prior to finetuning
@@ -217,7 +217,7 @@ class EarlyStoppingConfig(ZambaBaseModel):
monitor (str): Metric to be monitored. Options are "val_macro_f1" or
"val_loss". Defaults to "val_macro_f1".
patience (int): Number of epochs with no improvement after which training
will be stopped. Defaults to 3.
will be stopped. Defaults to 5.
verbose (bool): Verbosity mode. Defaults to True.
mode (str, optional): Options are "min" or "max". In "min" mode, training
will stop when the quantity monitored has stopped decreasing and in
@@ -226,7 +226,7 @@ class EarlyStoppingConfig(ZambaBaseModel):
"""

monitor: MonitorEnum = "val_macro_f1"
patience: int = 3
patience: int = 5
verbose: bool = True
mode: Optional[str] = None

@@ -300,10 +300,13 @@ class TrainConfig(ZambaBaseModel):
Defaults to False.
batch_size (int): Batch size to use for training. Defaults to 2.
auto_lr_find (bool): Use a learning rate finder algorithm when calling
trainer.tune() to find a optimal initial learning rate. Defaults to True.
trainer.tune() to try to find an optimal initial learning rate. Defaults to
False. The learning rate finder is not guaranteed to find a good learning
rate; depending on the dataset, it can select a learning rate that leads to
poor model training. Use with caution.
backbone_finetune_params (BackboneFinetuneConfig, optional): Set parameters
to finetune a backbone model to align with the current learning rate.
Defaults to a BackboneFinetuneConfig(unfreeze_backbone_at_epoch=15,
Defaults to a BackboneFinetuneConfig(unfreeze_backbone_at_epoch=5,
backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False,
train_bn=False, verbose=True).
gpus (int): Number of GPUs to train on applied per node.
@@ -316,7 +319,7 @@ class TrainConfig(ZambaBaseModel):
early_stopping_config (EarlyStoppingConfig, optional): Configuration for
early stopping, which monitors a metric during training and stops training
when the metric stops improving. Defaults to EarlyStoppingConfig(monitor='val_macro_f1',
patience=3, verbose=True, mode='max').
patience=5, verbose=True, mode='max').
weight_download_region (str): s3 region to download pretrained weights from.
Options are "us" (United States), "eu" (European Union), or "asia"
(Asia Pacific). Defaults to "us".
@@ -354,7 +357,7 @@ class TrainConfig(ZambaBaseModel):
model_name: Optional[ModelEnum] = ModelEnum.time_distributed
dry_run: Union[bool, int] = False
batch_size: int = 2
auto_lr_find: bool = True
auto_lr_find: bool = False
backbone_finetune_config: Optional[BackboneFinetuneConfig] = BackboneFinetuneConfig()
gpus: int = GPUS_AVAILABLE
num_workers: int = 3
6 changes: 3 additions & 3 deletions zamba/pytorch_lightning/utils.py
@@ -96,7 +96,7 @@ def train_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
shuffle=True,
multiprocessing_context=self.multiprocessing_context,
prefetch_factor=self.prefetch_factor,
persistent_workers=True,
persistent_workers=self.num_workers > 0,
)

def val_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
@@ -108,7 +108,7 @@ def val_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
shuffle=False,
multiprocessing_context=self.multiprocessing_context,
prefetch_factor=self.prefetch_factor,
persistent_workers=True,
persistent_workers=self.num_workers > 0,
)

def test_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
@@ -120,7 +120,7 @@ def test_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
shuffle=False,
multiprocessing_context=self.multiprocessing_context,
prefetch_factor=self.prefetch_factor,
persistent_workers=True,
persistent_workers=self.num_workers > 0,
)

def predict_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
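The `persistent_workers` change matters because `torch.utils.data.DataLoader` rejects `persistent_workers=True` whenever `num_workers` is 0, which is exactly the situation on machines running without worker processes. A standalone illustration of the guard (plain PyTorch, not zamba code):

```python
# Standalone illustration of the persistent_workers guard (plain PyTorch).
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
num_workers = 0  # e.g. no multiprocessing workers available

# DataLoader(..., num_workers=0, persistent_workers=True) raises
# "ValueError: persistent_workers option needs num_workers > 0".
loader = DataLoader(
    dataset,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,  # same guard as the fix above
)
```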
