I ran these experiments earlier but hit some errors. I will try again once the model and dataset are finalized. Pasting the code below for reference.

Code using Ray Tune (it still needs some adjustment, since it may raise errors):
```python
import lightning.pytorch as pl
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func(config):
    # ComsoDataModule (defined below) also requires data_dir;
    # base_dir is a module-level path defined elsewhere in the project.
    dm = ComsoDataModule(data_dir=base_dir, batch_size=config["batch_size"])
    model = CNN(**config)
    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, datamodule=dm)


# model_kwargs, MIN_VALS, and MAX_VALS are defined elsewhere.
search_space = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64]),
    "model_kwargs": model_kwargs,
    "wd": tune.choice([1e-5]),
    "beta1": tune.choice([0.9]),
    "beta2": tune.choice([0.99]),
    "minimum": tune.choice([MIN_VALS]),
    "maximum": tune.choice([MAX_VALS]),
}

scaling_config = ScalingConfig(
    num_workers=3, use_gpu=False  # , resources_per_worker={"CPU": 1, "GPU": 1}
)

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    ),
)

# Define a TorchTrainer without hyperparameters; the Tuner injects them
# through train_loop_config.
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)


def tune_mnist_asha(num_samples=10):
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
    tuner = tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": search_space},
        tune_config=tune.TuneConfig(
            metric="val_loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
    )
    return tuner.fit()


# The maximum number of training epochs
num_epochs = 5
# Number of samples from the parameter space
num_samples = 10
results = tune_mnist_asha(num_samples=num_samples)
```
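Once the run finishes, the best trial can be pulled out of the returned `ResultGrid`. A minimal sketch, assuming the run above completed and `val_loss` was reported:

```python
# Sketch: inspect the ResultGrid returned by tuner.fit().
best = results.get_best_result(metric="val_loss", mode="min")
print(best.config)               # winning hyperparameters (under "train_loop_config")
print(best.metrics["val_loss"])  # best reported validation loss
print(best.checkpoint)           # checkpoint retained by CheckpointConfig
```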
where the data module is defined as:
```python
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from torchvision.transforms import v2

# CustomImageDataset and MyRotationTransform are defined elsewhere in the project.

class ComsoDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = v2.Compose([
            MyRotationTransform(angles=[90, 180, 270]),
            v2.ToDtype(torch.float32),  # , scale=False),
        ])

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = CustomImageDataset(f'{self.data_dir}/train', normalized_cosmo_params_path=f'{self.data_dir}/train/train_normalized_params.csv', transform=self.transform)
        if stage in ("fit", "validate"):
            # Validation runs inside trainer.fit(), so the val set must also
            # be built when stage == "fit", not only for trainer.validate().
            self.val_dataset = CustomImageDataset(f'{self.data_dir}/val', normalized_cosmo_params_path=f'{self.data_dir}/val/val_normalized_params.csv', transform=None)
        if stage == "test":
            self.test_dataset = CustomImageDataset(f'{self.data_dir}/test', normalized_cosmo_params_path=f'{self.data_dir}/test/test_normalized_params.csv', transform=None)

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=3)

    def val_dataloader(self):
        # batch_size must come from self; the bare name was a NameError.
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=3)

    def test_dataloader(self):
        return DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=3)
```
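Before handing the data module to Ray, it can be smoke-tested standalone. A minimal sketch, assuming `base_dir` points at the prepared data and the dataset yields (image, params) pairs:

```python
# Quick standalone check of the data module before tuning.
dm = ComsoDataModule(data_dir=base_dir, batch_size=32)
dm.setup("fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape, x.dtype, y.shape)  # verify tensor shapes and dtypes
```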
I also had to modify the LightningModule class slightly. Example:
```python
class CNN(pl.LightningModule):
    def __init__(self, model_kwargs, lr, wd, beta1, beta2, batch_size, minimum, maximum):
        super().__init__()
        self.save_hyperparameters()
        # model_o3_err is defined elsewhere in the project.
        self.model = model_o3_err(**model_kwargs)
        # Note: grabbing a batch from a global train_loader will fail inside a
        # Ray worker, where that global does not exist; a fixed example shape is safer.
        self.example_input_array = next(iter(train_loader))[0]
        self.maximum = maximum
        self.minimum = minimum
        ...
```
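The elided parts matter for the tuning setup above: `RayTrainReportCallback` and `checkpoint_score_attribute="val_loss"` both require the module to log a `val_loss` metric. A hypothetical sketch of those pieces, assuming an Adam optimizer and an MSE objective (not the actual implementation; these methods would sit inside `CNN`):

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch of the elided CNN methods.
def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.mse_loss(self.model(x), y)
    # "val_loss" must be logged so RayTrainReportCallback can report it and
    # CheckpointConfig can rank checkpoints by it.
    self.log("val_loss", loss, sync_dist=True)
    return loss

def configure_optimizers(self):
    # save_hyperparameters() makes the __init__ args available via self.hparams.
    return torch.optim.Adam(
        self.parameters(),
        lr=self.hparams.lr,
        betas=(self.hparams.beta1, self.hparams.beta2),
        weight_decay=self.hparams.wd,
    )
```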