Upgrading examples to lightning 2.0 #60

Open
mikesol opened this issue May 31, 2023 · 0 comments

mikesol commented May 31, 2023

Hi!

I tried to get the compressor example up and running and, in the process, migrated it as best I could to lightning 2.0. There were several breaking changes, and since I've never used lightning before and am not yet familiar with auraloss, I'm not entirely sure everything works as intended. But it's training and the loss is going down, so that's a good sign!
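For context, the kinds of 1.x → 2.0 breaking changes I had to work around look roughly like the sketch below. It's illustrative only (placeholder model, not the actual diff from the branch): the *_epoch_end hooks that received collected outputs are gone, and Trainer(gpus=N) has been replaced by accelerator/devices.

import torch
import lightning.pytorch as pl  # 2.0 ships as lightning.pytorch; pytorch_lightning remains an alias

class TinyModel(pl.LightningModule):  # placeholder, not the compressor model
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)
        # 2.0 removed validation_epoch_end(self, outputs), so step outputs
        # must be collected manually if they are needed at epoch end
        self.val_outputs = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.l1_loss(self.net(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.l1_loss(self.net(x), y)
        self.val_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        # replaces validation_epoch_end(self, outputs)
        self.log("val_loss", torch.stack(self.val_outputs).mean())
        self.val_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# Trainer(gpus=N) no longer exists in 2.0; accelerator/devices replace it
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)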

The branch is here: https://github.com/mikesol/auraloss/tree/compressor-test.

The command I used locally is:

python examples/compressor/train_comp.py \
   fit \
   --data.root_dir SignalTrain_LA2A_Dataset_1.1 \
   --trainer.max_epochs 20 \
   --model.kernel_size 15 \
   --model.channel_width 32 \
   --model.dilation_growth 2 \
   --data.preload False \
   --data.num_workers 8 \
   --data.shuffle True \
   --data.batch_size 32 \
   --model.nparams 2 \
   --data.length 32768
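(The fit subcommand and the --model.* / --data.* / --trainer.* flags come from LightningCLI. The entry point in the branch boils down to something like the following; the class names and import paths here are placeholders rather than the exact ones in train_comp.py.)

from lightning.pytorch.cli import LightningCLI

# Placeholder imports standing in for the compressor LightningModule and the
# SignalTrain LightningDataModule used by examples/compressor
from comp_model import Compressor            # hypothetical module
from comp_data import SignalTrainDataModule  # hypothetical module

def cli_main():
    # LightningCLI provides the fit/validate/test/predict subcommands and maps
    # --model.*, --data.* and --trainer.* flags onto the __init__ arguments of
    # the LightningModule, the LightningDataModule and the Trainer respectively
    LightningCLI(Compressor, SignalTrainDataModule)

if __name__ == "__main__":
    cli_main()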

And the log so far shows:

(.venv) 21:42 meeshkan-abel@Abel:~/mike/auraloss$ python examples/compressor/train_comp.py    fit    --data.root_dir SignalTrain_LA2A_Dataset_1.1    --trainer.max_epochs 20    --model.kernel_size 15    --model.channel_width 32    --model.dilation_growth 2    --data.preload False    --data.num_workers 8    --data.shuffle True    --data.batch_size 32    --model.nparams 2    --data.length 32768
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/fabric/utilities/seed.py:39: UserWarning: No seed found, seed set to 3866398735
  rank_zero_warn(f"No seed found, seed set to {seed}")
Global seed set to 3866398735
Located 94285 examples totaling 19.5 hr in the train subset.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
[rank: 0] Global seed set to 3866398735
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Global seed set to 3866398735
Located 94285 examples totaling 19.5 hr in the train subset.
[rank: 1] Global seed set to 3866398735
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

Located 94285 examples totaling 19.5 hr in the train subset.
Located 94285 examples totaling 19.5 hr in the train subset.
Located 94285 examples totaling 19.5 hr in the train subset.
Located 94285 examples totaling 19.5 hr in the train subset.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name    | Type                     | Params
------------------------------------------------------
0  | l1      | L1Loss                   | 0     
1  | esr     | ESRLoss                  | 0     
2  | dc      | DCLoss                   | 0     
3  | logcosh | LogCoshLoss              | 0     
4  | sisdr   | SISDRLoss                | 0     
5  | stft    | STFTLoss                 | 0     
6  | mrstft  | MultiResolutionSTFTLoss  | 0     
7  | rrstft  | RandomResolutionSTFTLoss | 0     
8  | gen     | Sequential               | 10.5 K
9  | blocks  | ModuleList               | 221 K 
10 | output  | Conv1d                   | 33    
------------------------------------------------------
232 K     Trainable params
0         Non-trainable params
232 K     Total params
0.930     Total estimated model params size (MB)
Sanity Checking DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.30it/s]
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/L1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/ESR', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/DC', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/LogCosh', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/SI-SDR', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/STFT', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/MRSTFT', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
Epoch 0:  24%|█████████████████████▎                                                                  | 357/1474 [09:08<28:34,  1.54s/it, v_num=5, train_loss_step=0.691]
Epoch 0:  31%|███████████████████████████                                                             | 453/1474 [11:21<25:37,  1.51s/it, v_num=5, train_loss_step=0.858]
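Two things stand out in the log. The tensorboardX warning just means the CSVLogger is now the default; installing lightning[extra] or tensorboard restores TensorBoard logging. The PossibleUserWarnings about sync_dist come from the epoch-level self.log(...) calls; if the validation metrics should be averaged across both DDP ranks, the calls would need something along these lines (a sketch with a hypothetical class name, not what the branch currently does):

import torch
import lightning.pytorch as pl

class CompressorTask(pl.LightningModule):  # hypothetical stand-in for the example's module
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv1d(1, 1, kernel_size=15, padding=7)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = torch.nn.functional.l1_loss(self.net(x), y)
        # sync_dist=True reduces (averages) the epoch-level value across both
        # DDP ranks before it is written out, which is what the warning asks for
        self.log("val_loss", val_loss, on_epoch=True, sync_dist=True)
        return val_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())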

If you're interested in updating to lightning 2.0, I'd be happy to help out. The branch definitely isn't in good enough shape yet for a PR, but maybe if you take a look at the diff you'll see which elements needed tweaking and we could take it from there. Thanks & lemme know!
