
[WIP] fix: updates the training sampling strategy to complete the last batch #538

Draft · wants to merge 1 commit into main from dev.vg/438-batch-size
Conversation

@wiitt (Collaborator) commented Aug 30, 2024

Fixes #438

PR Goal?

Updates the sampling strategy in training so that the last batch of each epoch is completed with random samples drawn from other batches, instead of being dropped.
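
The core idea, as a minimal sketch (the class name and API below are illustrative assumptions, not the PR's actual `oversampler.py` implementation):

```python
# Illustrative sketch only: names and API are assumptions, not the PR's code.
import random

from torch.utils.data import Sampler


class LastBatchCompletingSampler(Sampler):
    """Yields full batches of indices; the final partial batch is topped up
    with random indices from earlier batches instead of being dropped."""

    def __init__(self, dataset_len: int, batch_size: int, shuffle: bool = True):
        # Assumes dataset_len >= batch_size, so there are earlier batches
        # to sample from when topping up the last one.
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        indices = list(range(self.dataset_len))
        if self.shuffle:
            random.shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            batch = indices[start : start + self.batch_size]
            if len(batch) < self.batch_size:
                # Complete the last batch with random samples from other batches.
                batch += random.sample(indices[:start], self.batch_size - len(batch))
            yield batch

    def __len__(self):
        # Ceiling division: the partial batch is kept, not dropped.
        return -(-self.dataset_len // self.batch_size)
```

A sampler like this would be passed as `batch_sampler=` to the `DataLoader`, so every batch the model sees has exactly `batch_size` items.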

Fixes?

Fixes #438

Feedback sought?

Whether a model works with this new sampler, and whether it produces results that are better, or at least not worse, after training.

Priority?

Low

Tests added?

No tests added, but it would be good to have some testing of this sampler.

How to test?

Place a breakpoint and inspect the composition of the last batch in an epoch. Check that the number of batches corresponds to expectations during training. Train a model in a scenario where the difference between dropping and keeping the last batch is noticeable (e.g. a very small dataset, or a dataset where the samples in the last batch contain unique phonemes).
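
A rough unit test of the expected invariants, written against the illustrative sampler sketched earlier (the real sampler's class name and constructor may differ):

```python
# Rough test sketch against the illustrative LastBatchCompletingSampler above.
import math


def test_last_batch_is_completed():
    dataset_len, batch_size = 103, 16
    sampler = LastBatchCompletingSampler(dataset_len, batch_size)
    batches = list(sampler)
    # The last batch is completed, not dropped, so we expect ceil(N / B) batches.
    assert len(batches) == math.ceil(dataset_len / batch_size)
    # Every batch is full.
    assert all(len(b) == batch_size for b in batches)
    # Every sample still appears at least once per epoch.
    assert {i for b in batches for i in b} == set(range(dataset_len))
```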

Confidence?

Low. This code wasn't properly tested.

Version change?

No. This can be part of a larger update.

Related PRs?

No.

semanticdiff-com bot commented Aug 30, 2024

Review changes with SemanticDiff

Changed Files
File                                    Status
everyvoice/dataloader/__init__.py       60% smaller
everyvoice/dataloader/oversampler.py    0% smaller

@wiitt wiitt marked this pull request as draft August 30, 2024 21:52
@wiitt wiitt requested a review from roedoejet August 30, 2024 21:52
github-actions bot (Contributor) commented Aug 30, 2024

CLI load time: 0:00.26
Pull Request HEAD: c3b6c71e91a6d0613876cb172f5b0510ee4c612b
Imports that take more than 0.1 s:
import time: self [us] | cumulative | imported package
import time:      1011 |     101921 |     typer.main
import time:       303 |     120869 |   typer
import time:      7894 |     198903 | everyvoice.cli

codecov bot commented Aug 30, 2024

Codecov Report

Attention: Patch coverage is 19.51220% with 33 lines in your changes missing coverage. Please review.

Project coverage is 75.72%. Comparing base (7e5cc06) to head (c3b6c71).

Files with missing lines                Patch %   Lines
everyvoice/dataloader/oversampler.py    20.00%    28 Missing ⚠️
everyvoice/dataloader/__init__.py       16.66%    5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #538      +/-   ##
==========================================
- Coverage   76.34%   75.72%   -0.62%     
==========================================
  Files          47       48       +1     
  Lines        3483     3522      +39     
  Branches      479      486       +7     
==========================================
+ Hits         2659     2667       +8     
- Misses        721      752      +31     
  Partials      103      103              


@marctessier (Collaborator) commented
I have a bunch of experiments set up using the latest version of EV and this PR, but I am unable to get this PR to start training.

I originally set up an experiment using the "z" character (removed it from training, then added it back), and I was having a hard time even confirming that this bug exists, since I was able to synthesize the "z" sound...

I was in the process of setting up another experiment using the "J" character when I noticed that none of the training runs using this PR were able to train successfully. I tried LJ and other datasets with no modifications. I always get the error below when I use this PR to train.

I am not convinced this PR is ready to be released.

See attached log: moh_538.e3559722.txt

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/TxT2SPEECH/EveryVoice_538_dev.vg_4 │
│ 38-batch-size/everyvoice/model/feature_prediction/FastSpeech2_lightning/fs2/ │
│ cli/train.py:28 in train                                                     │
│                                                                              │
│   25 │                                                                       │
│   26 │   model_kwargs = {"lang2id": lang2id, "speaker2id": speaker2id, "stat │
│   27 │                                                                       │
│ ❱ 28 │   train_base_command(                                                 │
│   29 │   │   model_config=FastSpeech2Config,                                 │
│   30 │   │   model=FastSpeech2,                                              │
│   31 │   │   data_module=FastSpeech2DataModule,                              │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/TxT2SPEECH/EveryVoice_538_dev.vg_4 │
│ 38-batch-size/everyvoice/base_cli/helpers.py:288 in train_base_command       │
│                                                                              │
│   285 │   │   model_obj = model(config, **model_kwargs)                      │
│   286 │   │   logger.info(f"Model's architecture\n{model_obj}")              │
│   287 │   │   tensorboard_logger.log_hyperparams(config.model_dump())        │
│ ❱ 288 │   │   trainer.fit(model_obj, data)                                   │
│   289 │   else:                                                              │
│   290 │   │   model_obj = model.load_from_checkpoint(last_ckpt)              │
│   291 │   │   logger.info(f"Model's architecture\n{model_obj}")              │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/trainer/tr │
│ ainer.py:539 in fit                                                          │
│                                                                              │
│    536 │   │   self.state.fn = TrainerFn.FITTING                             │
│    537 │   │   self.state.status = TrainerStatus.RUNNING                     │
│    538 │   │   self.training = True                                          │
│ ❱  539 │   │   call._call_and_handle_interrupt(                              │
│    540 │   │   │   self, self._fit_impl, model, train_dataloaders, val_datal │
│    541 │   │   )                                                             │
│    542                                                                       │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/trainer/ca │
│ ll.py:46 in _call_and_handle_interrupt                                       │
│                                                                              │
│    43 │   """                                                                │
│    44 │   try:                                                               │
│    45 │   │   if trainer.strategy.launcher is not None:                      │
│ ❱  46 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args, │
│    47 │   │   return trainer_fn(*args, **kwargs)                             │
│    48 │                                                                      │
│    49 │   except _TunerExitException:                                        │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/strategies │
│ /launchers/subprocess_script.py:105 in launch                                │
│                                                                              │
│   102 │   │   │   _launch_process_observer(self.procs)                       │
│   103 │   │                                                                  │
│   104 │   │   _set_num_threads_if_needed(num_processes=self.num_processes)   │
│ ❱ 105 │   │   return function(*args, **kwargs)                               │
│   106 │                                                                      │
│   107 │   @override                                                          │
│   108 │   def kill(self, signum: _SIGNUM) -> None:                           │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/trainer/tr │
│ ainer.py:575 in _fit_impl                                                    │
│                                                                              │
│    572 │   │   │   model_provided=True,                                      │
│    573 │   │   │   model_connected=self.lightning_module is not None,        │
│    574 │   │   )                                                             │
│ ❱  575 │   │   self._run(model, ckpt_path=ckpt_path)                         │
│    576 │   │                                                                 │
│    577 │   │   assert self.state.stopped                                     │
│    578 │   │   self.training = False                                         │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/trainer/tr │
│ ainer.py:982 in _run                                                         │
│                                                                              │
│    979 │   │   # ----------------------------                                │
│    980 │   │   # RUN THE TRAINER                                             │
│    981 │   │   # ----------------------------                                │
│ ❱  982 │   │   results = self._run_stage()                                   │
│    983 │   │                                                                 │
│    984 │   │   # ----------------------------                                │
│    985 │   │   # POST-Training CLEAN UP                                      │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/trainer/tr │
│ ainer.py:1024 in _run_stage                                                  │
│                                                                              │
│   1021 │   │   │   return self.predict_loop.run()                            │
│   1022 │   │   if self.training:                                             │
│   1023 │   │   │   with isolate_rng():                                       │
│ ❱ 1024 │   │   │   │   self._run_sanity_check()                              │
│   1025 │   │   │   with torch.autograd.set_detect_anomaly(self._detect_anoma │
│   1026 │   │   │   │   self.fit_loop.run()                                   │
│   1027 │   │   │   return None                                               │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/trainer/tr │
│ ainer.py:1053 in _run_sanity_check                                           │
│                                                                              │
│   1050 │   │   │   call._call_callback_hooks(self, "on_sanity_check_start")  │
│   1051 │   │   │                                                             │
│   1052 │   │   │   # run eval step                                           │
│ ❱ 1053 │   │   │   val_loop.run()                                            │
│   1054 │   │   │                                                             │
│   1055 │   │   │   call._call_callback_hooks(self, "on_sanity_check_end")    │
│   1056                                                                       │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/loops/util │
│ ities.py:179 in _decorator                                                   │
│                                                                              │
│   176 │   │   else:                                                          │
│   177 │   │   │   context_manager = torch.no_grad                            │
│   178 │   │   with context_manager():                                        │
│ ❱ 179 │   │   │   return loop_run(self, *args, **kwargs)                     │
│   180 │                                                                      │
│   181 │   return _decorator                                                  │
│   182                                                                        │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/loops/eval │
│ uation_loop.py:137 in run                                                    │
│                                                                              │
│   134 │   │   │   │   │   dataloader_idx = data_fetcher._dataloader_idx      │
│   135 │   │   │   │   else:                                                  │
│   136 │   │   │   │   │   dataloader_iter = None                             │
│ ❱ 137 │   │   │   │   │   batch, batch_idx, dataloader_idx = next(data_fetch │
│   138 │   │   │   │   if previous_dataloader_idx != dataloader_idx:          │
│   139 │   │   │   │   │   # the dataloader has changed, notify the logger co │
│   140 │   │   │   │   │   self._store_dataloader_outputs()                   │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/loops/fetc │
│ hers.py:134 in __next__                                                      │
│                                                                              │
│   131 │   │   │   │   self.done = not self.batches                           │
│   132 │   │   elif not self.done:                                            │
│   133 │   │   │   # this will run only when no pre-fetching was done.        │
│ ❱ 134 │   │   │   batch = super().__next__()                                 │
│   135 │   │   else:                                                          │
│   136 │   │   │   # the iterator is empty                                    │
│   137 │   │   │   raise StopIteration                                        │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/loops/fetc │
│ hers.py:61 in __next__                                                       │
│                                                                              │
│    58 │   │   assert self.iterator is not None                               │
│    59 │   │   self._start_profiler()                                         │
│    60 │   │   try:                                                           │
│ ❱  61 │   │   │   batch = next(self.iterator)                                │
│    62 │   │   except StopIteration:                                          │
│    63 │   │   │   self.done = True                                           │
│    64 │   │   │   raise                                                      │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/utilities/ │
│ combined_loader.py:341 in __next__                                           │
│                                                                              │
│   338 │                                                                      │
│   339 │   def __next__(self) -> _ITERATOR_RETURN:                            │
│   340 │   │   assert self._iterator is not None                              │
│ ❱ 341 │   │   out = next(self._iterator)                                     │
│   342 │   │   if isinstance(self._iterator, _Sequential):                    │
│   343 │   │   │   return out                                                 │
│   344 │   │   out, batch_idx, dataloader_idx = out                           │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/pytorch_lightning/utilities/ │
│ combined_loader.py:142 in __next__                                           │
│                                                                              │
│   139 │   │   │   │   │   raise StopIteration                                │
│   140 │   │                                                                  │
│   141 │   │   try:                                                           │
│ ❱ 142 │   │   │   out = next(self.iterators[0])                              │
│   143 │   │   except StopIteration:                                          │
│   144 │   │   │   # try the next iterator                                    │
│   145 │   │   │   self._use_next_iterator()                                  │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/torch/utils/data/dataloader. │
│ py:630 in __next__                                                           │
│                                                                              │
│    627 │   │   │   if self._sampler_iter is None:                            │
│    628 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/7675 │
│    629 │   │   │   │   self._reset()  # type: ignore[call-arg]               │
│ ❱  630 │   │   │   data = self._next_data()                                  │
│    631 │   │   │   self._num_yielded += 1                                    │
│    632 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \      │
│    633 │   │   │   │   │   self._IterableDataset_len_called is not None and  │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/torch/utils/data/dataloader. │
│ py:674 in _next_data                                                         │
│                                                                              │
│    671 │                                                                     │
│    672 │   def _next_data(self):                                             │
│    673 │   │   index = self._next_index()  # may raise StopIteration         │
│ ❱  674 │   │   data = self._dataset_fetcher.fetch(index)  # may raise StopIt │
│    675 │   │   if self._pin_memory:                                          │
│    676 │   │   │   data = _utils.pin_memory.pin_memory(data, self._pin_memor │
│    677 │   │   return data                                                   │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/miniforge3/envs/EveryVoice_538_dev │
│ .vg_438-batch-size/lib/python3.10/site-packages/torch/utils/data/_utils/fetc │
│ h.py:54 in fetch                                                             │
│                                                                              │
│   51 │   │   │   │   data = [self.dataset[idx] for idx in possibly_batched_i │
│   52 │   │   else:                                                           │
│   53 │   │   │   data = self.dataset[possibly_batched_index]                 │
│ ❱ 54 │   │   return self.collate_fn(data)                                    │
│   55                                                                         │
│                                                                              │
│ /gpfs/fs5/nrc/nrc-fs1/ict/others/u/tes001/TxT2SPEECH/EveryVoice_538_dev.vg_4 │
│ 38-batch-size/everyvoice/model/feature_prediction/FastSpeech2_lightning/fs2/ │
│ dataset.py:242 in collate_method                                             │
│                                                                              │
│   239 │   │   │   │   │   dur_padded.zero_()                                 │
│   240 │   │   │   │   │   for i in range(len(data[key])):                    │
│   241 │   │   │   │   │   │   dur = data[key][i]                             │
│ ❱ 242 │   │   │   │   │   │   dur_padded[i, : dur.size(0), : dur.size(1)] =  │
│   243 │   │   │   │   │   data[key] = dur_padded                             │
│   244 │   │   │   │   else:                                                  │
│   245 │   │   │   │   │   data[key] = pad_sequence(                          │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: The expanded size of the tensor (48) must match the existing size (49) at non-singleton dimension 1. Target sizes: [638, 48]. Tensor sizes: [638, 49]
Loading EveryVoice modules: 100%|██████████| 4/4 [00:04<00:00,  1.08s/it]   
srun: error: ib14gpu-001: task 0: Exited with exit code 1
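
For what it's worth, this shape mismatch is the kind of error raised when a padded buffer in a collate function is allocated narrower than one of the samples being copied into it, which could happen if a batch now mixes samples that were padded for different batches. A minimal sketch reproducing the same class of error (illustrative only, not the actual EveryVoice collate code):

```python
# Illustrative repro of the error class, not the actual EveryVoice collate.
import torch

batch = [torch.ones(638, 48), torch.ones(638, 49)]  # second dims disagree
# Bug: buffer width taken from the first sample instead of the batch maximum.
dur_padded = torch.zeros(len(batch), 638, batch[0].size(1))
for i, dur in enumerate(batch):
    # Raises: RuntimeError: The expanded size of the tensor (48) must match
    # the existing size (49) at non-singleton dimension 1.
    dur_padded[i, : dur.size(0), : dur.size(1)] = dur
```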

@roedoejet roedoejet force-pushed the dev.vg/438-batch-size branch from ad8cd7f to c3b6c71 Compare January 22, 2025 21:53
Successfully merging this pull request may close these issues.

Synthesize can only process even multiples of the batch size