diff --git a/mohou/trainer.py b/mohou/trainer.py index d4449bf..bd483fd 100644 --- a/mohou/trainer.py +++ b/mohou/trainer.py @@ -455,6 +455,7 @@ def train( config: TrainConfig = TrainConfig(), device: Optional[torch.device] = None, is_stoppable: Optional[Callable[[TrainCache], bool]] = None, + num_workers: int = 2, ) -> None: r""" higher-level train function that auto create dataloader from the dataset @@ -462,9 +463,23 @@ def train( dataset_train, dataset_validate = split_with_ratio(dataset, config.valid_data_ratio) - train_loader = DataLoader(dataset=dataset_train, batch_size=config.batch_size, shuffle=True) + if len(dataset_train) < config.batch_size: + message = "dataset size is smaller than batch_size. drop_last is set to False" + logger.warn(change_color_to_yellow(message)) + + # drop last is necessary for batch normalization + train_loader = DataLoader( + dataset=dataset_train, + batch_size=config.batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=True, + ) validate_loader = DataLoader( - dataset=dataset_validate, batch_size=config.batch_size, shuffle=True + dataset=dataset_validate, + batch_size=config.batch_size, + shuffle=True, + num_workers=num_workers, ) train_lower( project_path,