diff --git a/mohou/trainer.py b/mohou/trainer.py index d4449bf..3fb0e4e 100644 --- a/mohou/trainer.py +++ b/mohou/trainer.py @@ -455,6 +455,7 @@ def train( config: TrainConfig = TrainConfig(), device: Optional[torch.device] = None, is_stoppable: Optional[Callable[[TrainCache], bool]] = None, + num_workers: int = 2, ) -> None: r""" higher-level train function that auto create dataloader from the dataset @@ -462,9 +463,26 @@ def train( dataset_train, dataset_validate = split_with_ratio(dataset, config.valid_data_ratio) - train_loader = DataLoader(dataset=dataset_train, batch_size=config.batch_size, shuffle=True) + drop_last = True + if len(dataset_train) < config.batch_size: + message = "dataset size is smaller than batch_size. drop_last is set to False" + logger.warn(change_color_to_yellow(message)) + drop_last = False + + # drop last is necessary for batch normalization + train_loader = DataLoader( + dataset=dataset_train, + batch_size=config.batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=drop_last, + ) validate_loader = DataLoader( - dataset=dataset_validate, batch_size=config.batch_size, shuffle=True + dataset=dataset_validate, + batch_size=config.batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=False, # drop_last is not necessary for validation as batch normalization is not used ) train_lower( project_path,