Skip to content

Commit

Permalink
fix/feat(trainer): drop_last = False and num_workers = 2
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroIshida committed Feb 22, 2024
1 parent 1929a0a commit 335eb67
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions mohou/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,16 +455,34 @@ def train(
config: TrainConfig = TrainConfig(),
device: Optional[torch.device] = None,
is_stoppable: Optional[Callable[[TrainCache], bool]] = None,
num_workers: int = 2,
) -> None:
r"""
higher-level train function that automatically creates the train and validation dataloaders from the dataset
"""

dataset_train, dataset_validate = split_with_ratio(dataset, config.valid_data_ratio)

train_loader = DataLoader(dataset=dataset_train, batch_size=config.batch_size, shuffle=True)
drop_last = True
if len(dataset_train) < config.batch_size:
message = "dataset size is smaller than batch_size. drop_last is set to False"
logger.warn(change_color_to_yellow(message))
drop_last = False

# drop_last is necessary for batch normalization; it is disabled only when the training set is smaller than one batch
train_loader = DataLoader(
dataset=dataset_train,
batch_size=config.batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=drop_last,
)
validate_loader = DataLoader(
dataset=dataset_validate, batch_size=config.batch_size, shuffle=True
dataset=dataset_validate,
batch_size=config.batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=False, # drop_last is not necessary for validation as batch normalization is not used
)
train_lower(
project_path,
Expand Down

0 comments on commit 335eb67

Please sign in to comment.