
Hyperparameter tuning #12

@Yash-10

I ran some experiments earlier but got errors. I will try again later, once the model and dataset are finalized. Pasting the code below for reference:

Code using Ray Tune (the code still needs some adjustment, since it may raise errors as-is):

import pytorch_lightning as pl  # pl.Trainer is used below

from ray import tune
from ray.tune.schedulers import ASHAScheduler

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func(config):
    # ComsoDataModule.__init__ also requires data_dir; base_dir is defined elsewhere
    dm = ComsoDataModule(data_dir=base_dir, batch_size=config['batch_size'])
    model = CNN(**config)

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, datamodule=dm)

search_space = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64]),
    "model_kwargs": model_kwargs,
    "wd": tune.choice([1e-5]),
    "beta1": tune.choice([0.9]),
    "beta2": tune.choice([0.99]),
    "minimum": tune.choice([MIN_VALS]),
    "maximum": tune.choice([MAX_VALS]),
}

from ray.train import RunConfig, ScalingConfig, CheckpointConfig

scaling_config = ScalingConfig(
    num_workers=3,
    use_gpu=False,
    # resources_per_worker={"CPU": 1, "GPU": 1},
)

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    ),
)

from ray.train.torch import TorchTrainer

# Define a TorchTrainer without hyper-parameters for Tuner
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)

def tune_mnist_asha(num_samples=10):
    # num_epochs is defined below, before this function is called
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": search_space},
        tune_config=tune.TuneConfig(
            metric="val_loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
    )
    return tuner.fit()

# The maximum number of training epochs per trial
num_epochs = 5

# Number of samples drawn from the parameter space
num_samples = 10

results = tune_mnist_asha(num_samples=num_samples)
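
Once the run completes, the best trial can be read off the returned ResultGrid; something like this should work (an untested sketch, assuming val_loss is reported as above):

best_result = results.get_best_result(metric="val_loss", mode="min")
print(best_result.config["train_loop_config"])  # the winning hyperparameters
print(best_result.metrics["val_loss"])          # its last reported validation loss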

where

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import v2

# CustomImageDataset, MyRotationTransform and base_dir are defined elsewhere

class ComsoDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = v2.Compose([
            MyRotationTransform(angles=[90, 180, 270]),
            v2.ToDtype(torch.float32),  # scale=False
        ])

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = CustomImageDataset(f'{base_dir}/train', normalized_cosmo_params_path=f'{base_dir}/train/train_normalized_params.csv', transform=self.transform)
            # Validation runs during fit, so the val dataset must be set up here as
            # well; otherwise val_loss (used by ASHA and checkpointing) is never logged.
            self.val_dataset = CustomImageDataset(f'{base_dir}/val', normalized_cosmo_params_path=f'{base_dir}/val/val_normalized_params.csv', transform=None)
        elif stage == "validate":
            self.val_dataset = CustomImageDataset(f'{base_dir}/val', normalized_cosmo_params_path=f'{base_dir}/val/val_normalized_params.csv', transform=None)
        elif stage == "test":
            self.test_dataset = CustomImageDataset(f'{base_dir}/test', normalized_cosmo_params_path=f'{base_dir}/test/test_normalized_params.csv', transform=None)

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=3)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=3)

    def test_dataloader(self):
        return DataLoader(dataset=self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=3)
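
As a quick sanity check of the wiring (hypothetical snippet; assumes base_dir and CustomImageDataset from above, and that the dataset yields (image, params) pairs):

dm = ComsoDataModule(data_dir=base_dir, batch_size=32)
dm.setup("fit")
xb, yb = next(iter(dm.train_dataloader()))
print(xb.shape, yb.shape)  # one training batch of images and normalized parameters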

I also needed to modify the LightningModule class a bit. Example:

class CNN(pl.LightningModule):

    def __init__(self, model_kwargs, lr, wd, beta1, beta2, batch_size, minimum, maximum):
        super().__init__()
        self.save_hyperparameters()
        self.model = model_o3_err(**model_kwargs)
        # Note: this relies on a global train_loader just to record an example input
        self.example_input_array = next(iter(train_loader))[0]

        self.maximum = maximum
        self.minimum = minimum
...
...
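
For completeness, the elided part would look roughly like this (a sketch under my assumptions about the loss and optimizer, not the final code). The key constraint is that validation_step logs "val_loss", since that is the metric the ASHA scheduler and the checkpoint config above select on, and configure_optimizers should consume lr, wd, beta1 and beta2:

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat, y)  # assuming a regression loss; F is torch.nn.functional
        # The name must match checkpoint_score_attribute and TuneConfig's metric
        self.log("val_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.wd,
            betas=(self.hparams.beta1, self.hparams.beta2),
        )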
