Skip to content

Commit

Permalink
speed up fsdp
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Nov 30, 2024
1 parent d43f6a0 commit 4ca7bbb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def get_trainer(
save_folder: Optional[str] = None,
save_filename: str = 'ba{batch}-rank{rank}.pt',
save_overwrite: bool = False,
num_features: int = 4,
num_classes: int = 2,
num_features: int = 4, # Reduced from default
num_classes: int = 2, # Reduced from default
load_path: Optional[str] = None,
autoresume: bool = False,
run_name: Optional[str] = None,
Expand All @@ -111,11 +111,11 @@ def get_trainer(
val_metrics=val_metrics,
)
model.module.to(model_init_device)
dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=128)
dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=32)
dataloader = DataLoader(
dataset,
sampler=dist.get_sampler(dataset),
batch_size=8,
batch_size=2,
)
if optimizer == 'adam':
optim = torch.optim.Adam(params=model.parameters())
Expand Down

0 comments on commit 4ca7bbb

Please sign in to comment.