diff --git a/docs/docs/configurations.md b/docs/docs/configurations.md
index 1aeb4541..9ca83cce 100644
--- a/docs/docs/configurations.md
+++ b/docs/docs/configurations.md
@@ -244,16 +244,16 @@ class TrainConfig(ZambaBaseModel)
 model_name: zamba.models.config.ModelEnum = ,
 dry_run: Union[bool, int] = False,
 batch_size: int = 2,
-auto_lr_find: bool = True,
+auto_lr_find: bool = False,
 backbone_finetune_config: zamba.models.config.BackboneFinetuneConfig =
-    BackboneFinetuneConfig(unfreeze_backbone_at_epoch=15,
+    BackboneFinetuneConfig(unfreeze_backbone_at_epoch=5,
     backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False,
     train_bn=False, verbose=True),
 gpus: int = 0,
 num_workers: int = 3,
 max_epochs: int = None,
 early_stopping_config: zamba.models.config.EarlyStoppingConfig =
-    EarlyStoppingConfig(monitor='val_macro_f1', patience=3,
+    EarlyStoppingConfig(monitor='val_macro_f1', patience=5,
     verbose=True, mode='max'),
 weight_download_region: zamba.models.utils.RegionEnum = 'us',
 split_proportions: Dict[str, int] = {'train': 3, 'val': 1, 'holdout': 1},
@@ -299,11 +299,11 @@ The batch size to use for training. Defaults to `2`
 
 #### `auto_lr_find (bool, optional)`
 
-Whether to run a [learning rate finder algorithm](https://arxiv.org/abs/1506.01186) when calling `pytorch_lightning.trainer.tune()` to find the optimal initial learning rate. See the PyTorch Lightning [docs](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#auto-lr-find) for more details. Defaults to `True`
+Whether to run a [learning rate finder algorithm](https://arxiv.org/abs/1506.01186) when calling `pytorch_lightning.trainer.tune()` to try to find an optimal initial learning rate. The learning rate finder is not guaranteed to find a good learning rate; depending on the dataset, it can select a learning rate that leads to poor model training. Use with caution. See the PyTorch Lightning [docs](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#auto-lr-find) for more details. Defaults to `False`.
 
 #### `backbone_finetune_config (zamba.models.config.BackboneFinetuneConfig, optional)`
 
-Set parameters to finetune a backbone model to align with the current learning rate. Derived from Pytorch Lightning's built-in [`BackboneFinetuning`](https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/finetuning.html). The default values are specified in the [`BackboneFinetuneConfig` class](api-reference/models-config.md#zamba.models.config.BackboneFinetuneConfig): `BackboneFinetuneConfig(unfreeze_backbone_at_epoch=15, backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False, train_bn=False, verbose=True)`
+Set parameters to finetune a backbone model to align with the current learning rate. Derived from Pytorch Lightning's built-in [`BackboneFinetuning`](https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/finetuning.html). The default values are specified in the [`BackboneFinetuneConfig` class](api-reference/models-config.md#zamba.models.config.BackboneFinetuneConfig): `BackboneFinetuneConfig(unfreeze_backbone_at_epoch=5, backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False, train_bn=False, verbose=True)`
 
 #### `gpus (int, optional)`
 
@@ -319,7 +319,7 @@ The maximum number of epochs to run during training.
 Defaults to `None`
 
 #### `early_stopping_config (zamba.models.config.EarlyStoppingConfig, optional)`
 
-Parameters to pass to Pytorch lightning's [`EarlyStopping`](https://github.com/PyTorchLightning/pytorch-lightning/blob/c7451b3ccf742b0e8971332caf2e041ceabd9fe8/pytorch_lightning/callbacks/early_stopping.py#L35) to monitor a metric during model training and stop training when the metric stops improving. The default values are specified in the [`EarlyStoppingConfig` class](api-reference/models-config.md#zamba.models.config.EarlyStoppingConfig): `EarlyStoppingConfig(monitor='val_macro_f1', patience=3, verbose=True, mode='max')`
+Parameters to pass to Pytorch lightning's [`EarlyStopping`](https://github.com/PyTorchLightning/pytorch-lightning/blob/c7451b3ccf742b0e8971332caf2e041ceabd9fe8/pytorch_lightning/callbacks/early_stopping.py#L35) to monitor a metric during model training and stop training when the metric stops improving. The default values are specified in the [`EarlyStoppingConfig` class](api-reference/models-config.md#zamba.models.config.EarlyStoppingConfig): `EarlyStoppingConfig(monitor='val_macro_f1', patience=5, verbose=True, mode='max')`
 
 #### `weight_download_region [us|eu|asia]`
@@ -351,4 +351,4 @@ Whether the species outputted by the model should be all zamba species. If you w
 
 #### `model_cache_dir (Path, optional)`
 
-Cache directory where downloaded model weights will be saved. If None and the MODEL_CACHE_DIR environment variable is not set, will use your default cache directory, which is often an automatic temp directory at `~/.cache/zamba`. Defaults to `None`.
\ No newline at end of file
+Cache directory where downloaded model weights will be saved. If None and the MODEL_CACHE_DIR environment variable is not set, will use your default cache directory, which is often an automatic temp directory at `~/.cache/zamba`. Defaults to `None`.
diff --git a/zamba/models/config.py b/zamba/models/config.py
index 9086249b..faba49dc 100644
--- a/zamba/models/config.py
+++ b/zamba/models/config.py
@@ -189,7 +189,7 @@ class BackboneFinetuneConfig(ZambaBaseModel):
 
     Args:
         unfreeze_backbone_at_epoch (int, optional): Epoch at which the backbone
-            will be unfrozen. Defaults to 15.
+            will be unfrozen. Defaults to 5.
         backbone_initial_ratio_lr (float, optional): Used to scale down the backbone
             learning rate compared to rest of model. Defaults to 0.01.
         multiplier (int or float, optional): Multiply the learning rate by a constant
@@ -202,7 +202,7 @@
             Defaults to True.
     """
 
-    unfreeze_backbone_at_epoch: Optional[int] = 15
+    unfreeze_backbone_at_epoch: Optional[int] = 5
     backbone_initial_ratio_lr: Optional[float] = 0.01
     multiplier: Optional[Union[int, float]] = 1
     pre_train_bn: Optional[bool] = False  # freeze batch norm layers prior to finetuning
@@ -217,7 +217,7 @@ class EarlyStoppingConfig(ZambaBaseModel):
         monitor (str): Metric to be monitored. Options are "val_macro_f1" or "val_loss".
             Defaults to "val_macro_f1".
         patience (int): Number of epochs with no improvement after which training
-            will be stopped. Defaults to 3.
+            will be stopped. Defaults to 5.
         verbose (bool): Verbosity mode. Defaults to True.
         mode (str, optional): Options are "min" or "max".
             In "min" mode, training will stop when the quantity monitored has stopped decreasing and in
@@ -226,7 +226,7 @@ class EarlyStoppingConfig(ZambaBaseModel):
     """
 
     monitor: MonitorEnum = "val_macro_f1"
-    patience: int = 3
+    patience: int = 5
     verbose: bool = True
     mode: Optional[str] = None
 
@@ -300,10 +300,13 @@ class TrainConfig(ZambaBaseModel):
             Defaults to False.
         batch_size (int): Batch size to use for training. Defaults to 2.
         auto_lr_find (bool): Use a learning rate finder algorithm when calling
-            trainer.tune() to find a optimal initial learning rate. Defaults to True.
+            trainer.tune() to try to find an optimal initial learning rate. Defaults to
+            False. The learning rate finder is not guaranteed to find a good learning
+            rate; depending on the dataset, it can select a learning rate that leads to
+            poor model training. Use with caution.
         backbone_finetune_params (BackboneFinetuneConfig, optional): Set parameters
             to finetune a backbone model to align with the current learning rate.
-            Defaults to a BackboneFinetuneConfig(unfreeze_backbone_at_epoch=15,
+            Defaults to a BackboneFinetuneConfig(unfreeze_backbone_at_epoch=5,
            backbone_initial_ratio_lr=0.01, multiplier=1, pre_train_bn=False,
             train_bn=False, verbose=True).
         gpus (int): Number of GPUs to train on applied per node.
@@ -316,7 +319,7 @@
         early_stopping_config (EarlyStoppingConfig, optional): Configuration for
             early stopping, which monitors a metric during training and stops training
             when the metric stops improving. Defaults to EarlyStoppingConfig(monitor='val_macro_f1',
-            patience=3, verbose=True, mode='max').
+            patience=5, verbose=True, mode='max').
         weight_download_region (str): s3 region to download pretrained weights from.
             Options are "us" (United States), "eu" (European Union), or "asia"
             (Asia Pacific). Defaults to "us".
@@ -354,7 +357,7 @@ class TrainConfig(ZambaBaseModel):
     model_name: Optional[ModelEnum] = ModelEnum.time_distributed
     dry_run: Union[bool, int] = False
     batch_size: int = 2
-    auto_lr_find: bool = True
+    auto_lr_find: bool = False
     backbone_finetune_config: Optional[BackboneFinetuneConfig] = BackboneFinetuneConfig()
     gpus: int = GPUS_AVAILABLE
     num_workers: int = 3
diff --git a/zamba/pytorch_lightning/utils.py b/zamba/pytorch_lightning/utils.py
index 5b7865c0..5de96058 100644
--- a/zamba/pytorch_lightning/utils.py
+++ b/zamba/pytorch_lightning/utils.py
@@ -96,7 +96,7 @@ def train_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
             shuffle=True,
             multiprocessing_context=self.multiprocessing_context,
             prefetch_factor=self.prefetch_factor,
-            persistent_workers=True,
+            persistent_workers=self.num_workers > 0,
         )
 
     def val_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
@@ -108,7 +108,7 @@
             shuffle=False,
             multiprocessing_context=self.multiprocessing_context,
             prefetch_factor=self.prefetch_factor,
-            persistent_workers=True,
+            persistent_workers=self.num_workers > 0,
         )
 
     def test_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
@@ -120,7 +120,7 @@
             shuffle=False,
             multiprocessing_context=self.multiprocessing_context,
             prefetch_factor=self.prefetch_factor,
-            persistent_workers=True,
+            persistent_workers=self.num_workers > 0,
         )
 
     def predict_dataloader(self) -> Optional[torch.utils.data.DataLoader]:
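
A note on the `persistent_workers` change above: PyTorch's `DataLoader` raises a `ValueError` when `persistent_workers=True` is combined with `num_workers=0`, so the unconditional `True` would error out whenever zamba ran without worker processes. Gating the flag on `self.num_workers > 0` keeps both configurations working. A minimal standalone sketch of the failure mode and the fix (the toy dataset here is illustrative, not zamba's actual video dataset):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-in for zamba's real dataset.
dataset = TensorDataset(torch.zeros(8, 3))
num_workers = 0  # e.g. an environment where multiprocessing is unavailable

# Old behavior: always passing persistent_workers=True fails when num_workers == 0,
# because PyTorch rejects the combination with a ValueError
# ("persistent_workers option needs num_workers > 0"):
# DataLoader(dataset, num_workers=num_workers, persistent_workers=True)  # raises

# New behavior: keep workers alive between epochs only when workers actually exist.
loader = DataLoader(dataset, num_workers=num_workers, persistent_workers=num_workers > 0)
print(len(loader))  # the loader is usable with or without worker processes
```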
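Similarly, since `auto_lr_find` now defaults to `False`, the learning rate finder becomes strictly opt-in. A hedged sketch of what opting back in could look like with the `TrainConfig` shown in the diff (the labels path is a placeholder, and `TrainConfig` performs additional validation, e.g. checking that files exist, beyond what is shown):

```python
from zamba.models.config import TrainConfig

# Hypothetical labels file; TrainConfig will validate that the path exists.
config = TrainConfig(
    labels="my_labels.csv",
    auto_lr_find=True,  # explicitly re-enable the learning rate finder
)
```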