Skip to content

Commit

Permalink
update sampler
Browse files Browse the repository at this point in the history
Signed-off-by: stevehuang52 <[email protected]>
  • Loading branch information
stevehuang52 committed Dec 3, 2024
1 parent 988e585 commit 7a7052d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/speechlm/conf/salm_fc_linear.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ data:
# Notably, the data weights are controlled by either bucketing_weights
# or concat_sampling_probabilities depending on the dataset type (tar and
# non-tar).
# See audio_text_qa_dataset.py for details.
concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the probability of sampling from each dataset when strategy='random'
context_key: ${data.common.context_key}
answer_key: ${data.common.answer_key}
Expand All @@ -59,6 +58,7 @@ data:
max_duration: 24 # set for LibriSpeech; you may need to update it for your dataset
min_duration: 0.1
# tarred datasets
is_concat: false
is_tarred: false
tarred_audio_filepaths: null
shuffle_n: 2048
Expand Down
6 changes: 5 additions & 1 deletion nemo/collections/speechlm/data/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterableDataset
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging

Expand All @@ -36,6 +36,10 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0
logging.info(f"Dataset {dataloader.dataset} does not have __len__ method. Skipping Megatron data sampler.")
return dataloader

if isinstance(dataloader.dataset, IterableDataset):
logging.info(f"Dataset {dataloader.dataset} is an IterableDataset. Skipping Megatron data sampler.")
return dataloader

from megatron.core import parallel_state

from nemo.lightning.data import add_megatron_sampler
Expand Down

0 comments on commit 7a7052d

Please sign in to comment.