From 3c44c362cac7325af818253fbb36d75d2a369d9f Mon Sep 17 00:00:00 2001
From: nkamil
Date: Thu, 4 Apr 2024 16:20:38 +0000
Subject: [PATCH 1/4] move to lightning

---
 environment.yml                | 2 +-
 openstl/api/exp.py             | 6 +++---
 openstl/datasets/base_data.py  | 4 ++--
 openstl/methods/base_method.py | 4 ++--
 requirements/runtime.txt       | 2 +-
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/environment.yml b/environment.yml
index 8a3e3ae9..debb5e78 100644
--- a/environment.yml
+++ b/environment.yml
@@ -8,7 +8,7 @@ dependencies:
   - hickle
   - pip
   - python<=3.10.8
-  - pytorch-lightning<=1.9.5
+  - lightning
   - pytorch<=2.1.1
   - xarray==0.19.0
   - pip:
diff --git a/openstl/api/exp.py b/openstl/api/exp.py
index 77f233b5..a0a9103b 100644
--- a/openstl/api/exp.py
+++ b/openstl/api/exp.py
@@ -12,8 +12,8 @@ from openstl.utils import (get_dataset, measure_throughput, SetupCallback,
                            EpochEndCallback, BestCheckpointCallback)
 
 import argparse
-from pytorch_lightning import seed_everything, Trainer
-import pytorch_lightning.callbacks as plc
+from lightning import seed_everything, Trainer
+import lightning.pytorch.callbacks as lc
 
 
 class BaseExperiment(object):
@@ -77,7 +77,7 @@ def _load_callbacks(self, args, save_dir, ckpt_dir):
         callbacks = [setup_callback, ckpt_callback, epochend_callback]
 
         if args.sched:
-            callbacks.append(plc.LearningRateMonitor(logging_interval=None))
+            callbacks.append(lc.LearningRateMonitor(logging_interval=None))
         return callbacks, save_dir
 
     def _get_data(self, dataloaders=None):
diff --git a/openstl/datasets/base_data.py b/openstl/datasets/base_data.py
index 2fdd0f75..ea522104 100644
--- a/openstl/datasets/base_data.py
+++ b/openstl/datasets/base_data.py
@@ -1,7 +1,7 @@
-import pytorch_lightning as pl
+import lightning as l
 
 
-class BaseDataModule(pl.LightningDataModule):
+class BaseDataModule(l.LightningDataModule):
     def __init__(self, train_loader, valid_loader, test_loader):
         super().__init__()
         self.train_loader = train_loader
diff --git a/openstl/methods/base_method.py b/openstl/methods/base_method.py
index 78a251b4..d4704dd9 100644
--- a/openstl/methods/base_method.py
+++ b/openstl/methods/base_method.py
@@ -1,13 +1,13 @@
 import numpy as np
 import torch.nn as nn
 import os.path as osp
-import pytorch_lightning as pl
+import lightning as l
 
 from openstl.utils import print_log, check_dir
 from openstl.core import get_optim_scheduler
 from openstl.core import metric
 
 
-class Base_method(pl.LightningModule):
+class Base_method(l.LightningModule):
     def __init__(self, **args):
         super().__init__()
 
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 30be7229..e04f29c8 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -17,5 +17,5 @@ scikit-learn
 timm
 tqdm
 xarray
-pytorch-lightning
+lightning
 PyWavelets

From e9269f7edbd6be82bd49eb432995f5b69becf6cb Mon Sep 17 00:00:00 2001
From: nkamil
Date: Thu, 4 Apr 2024 16:37:33 +0000
Subject: [PATCH 2/4] fixes

---
 openstl/utils/callbacks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/openstl/utils/callbacks.py b/openstl/utils/callbacks.py
index 1d45a543..93620a7d 100644
--- a/openstl/utils/callbacks.py
+++ b/openstl/utils/callbacks.py
@@ -2,7 +2,7 @@
 import shutil
 import logging
 import os.path as osp
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 
 from .main_utils import check_dir, collect_env, print_log, output_namespace
 
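The two patches above move OpenSTL from the legacy pytorch_lightning namespace to the unified lightning package: the high-level API (Trainer, seed_everything, LightningModule, LightningDataModule) is re-exported at the top level, while the PyTorch-specific pieces such as callbacks now live under lightning.pytorch. A minimal sketch of the resulting import layout (illustrative only, assuming lightning>=2.0 is installed; this snippet is not part of the patch series):

    import lightning as l
    from lightning import Trainer, seed_everything        # re-exported at the top level
    import lightning.pytorch.callbacks as lc              # callbacks sit under lightning.pytorch

    seed_everything(42)                                   # same call signature as in pytorch_lightning
    monitor = lc.LearningRateMonitor(logging_interval=None)
    trainer = Trainer(max_epochs=1, callbacks=[monitor])
    print(l.__version__)                                  # the unified package reports a single version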
From be64b4b2f9682791f48196c3f3b985f53d5150d9 Mon Sep 17 00:00:00 2001
From: Kamil <50270585+NinevskiyK@users.noreply.github.com>
Date: Sun, 5 May 2024 10:14:49 +0300
Subject: [PATCH 3/4] Update runtime.txt

---
 requirements/runtime.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index e04f29c8..aeef5ef1 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -17,5 +17,5 @@ scikit-learn
 timm
 tqdm
 xarray
-lightning
+lightning==2.2.1
 PyWavelets
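Pinning an exact version keeps training runs reproducible across installs instead of floating with the latest Lightning release. A small runtime guard in the same spirit (illustrative only; only the version string comes from the pin above, the check itself is not part of the patch):

    # Fail fast if the environment drifts from requirements/runtime.txt.
    import lightning

    assert lightning.__version__ == "2.2.1", (
        "expected lightning==2.2.1, found " + lightning.__version__
    )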
""" def __init__(self, data_root, data_name, training_time, - idx_in, idx_out, step=1, level=1, data_split='5_625', + idx_in, idx_out, step=1, levels=['50'], data_split='5_625', mean=None, std=None, transform_data=None, transform_labels=None, use_augment=False): super().__init__() self.data_root = data_root - self.data_name = data_name self.data_split = data_split self.training_time = training_time self.idx_in = np.array(idx_in) self.idx_out = np.array(idx_out) self.step = step - self.level = level self.data = None self.mean = mean self.std = std self.transform_data = transform_data self.transform_labels = transform_labels self.use_augment = use_augment - assert isinstance(level, (int, list)) self.time = None + self.time_size = self.training_time shape = int(32 * 5.625 / float(data_split.replace('_', '.'))) self.shape = (shape, shape * 2) - if isinstance(data_name, list): - data_name = data_name[0] - if 'mv' in data_name: # multi-variant version - self.data_name = mv_data_map[data_name] - self.data, self.mean, self.std = [], [], [] - for name in self.data_name: - data, mean, std = self._load_data_xarray(data_name=name, single_variant=False) - self.data.append(data) - self.mean.append(mean) - self.std.append(std) - self.data = np.concatenate(self.data, axis=1) - self.mean = np.concatenate(self.mean, axis=1) - self.std = np.concatenate(self.std, axis=1) - else: # single variant - self.data_name = data_name - self.data, mean, std = self._load_data_xarray(data_name, single_variant=True) - if self.mean is None: - self.mean, self.std = mean, std + self.data, self.mean, self.std = [], [], [] + + if levels == 'all': + levels = ['50', '250', '500', '600', '700', '850', '925'] + levels = levels if isinstance(levels, list) else [levels] + levels = [int(level) for level in levels] + if isinstance(data_name, str) and data_name in mv_data_map: + data_names = mv_data_map[data_name] + else: + data_names = data_name if isinstance(data_name, list) else [data_name] + + for name in tqdm.tqdm(data_names): + data, mean, std = self._load_data_xarray(data_name=name, levels=levels) + self.data.append(data) + self.mean.append(mean) + self.std.append(std) + + for i, data in enumerate(self.data): + if data.shape[0] != self.time_size: + self.data[i] = data.repeat(self.time_size, axis=0) + + self.data = np.concatenate(self.data, axis=1) + self.mean = np.concatenate(self.mean, axis=1) + self.std = np.concatenate(self.std, axis=1) self.valid_idx = np.array( range(-idx_in[0], self.data.shape[0]-idx_out[-1]-1)) - def _load_data_xarray(self, data_name, single_variant=True): + def _load_data_xarray(self, data_name, levels): """Loading full data with xarray""" - if data_name != 'uv10': - try: - dataset = xr.open_mfdataset(self.data_root+'/{}/{}*.nc'.format( - data_map[data_name], data_map[data_name]), combine='by_coords') - except (AttributeError, ValueError): - assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \ - 'pip install xarray==0.19.0,' \ - 'pip install netcdf4 h5netcdf dask' - except OSError: - print("OSError: Invalid path {}/{}/*.nc".format(self.data_root, data_map[data_name])) - assert False + try: + dataset = xr.open_mfdataset(self.data_root+'/{}/{}*.nc'.format( + data_map[data_name], data_map[data_name]), combine='by_coords') + except (AttributeError, ValueError): + assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \ + 'pip install xarray==0.19.0,' \ + 'pip install netcdf4 h5netcdf dask' + except OSError: + print("OSError: Invalid path 
{}/{}/*.nc".format(self.data_root, data_map[data_name])) + assert False + + if 'time' not in dataset.indexes: + dataset = dataset.expand_dims(dim={"time": 1}, axis=0) + else: dataset = dataset.sel(time=slice(*self.training_time)) dataset = dataset.isel(time=slice(None, -1, self.step)) - if self.time is None and single_variant: - self.week = dataset['time.week'] - self.month = dataset['time.month'] - self.year = dataset['time.year'] - self.time = np.stack( - [self.week, self.month, self.year], axis=1) - lon, lat = np.meshgrid( - (dataset.lon-180) * d2r, dataset.lat*d2r) - x, y, z = latlon2xyz(lat, lon) - self.V = np.stack([x, y, z]).reshape(3, self.shape[0]*self.shape[1]).T - if not single_variant and isinstance(self.level, list): - dataset = dataset.sel(level=np.array(self.level)) - data = dataset.get(data_name).values[:, np.newaxis, :, :] - - elif data_name == 'uv10': - input_datasets = [] - for key in ['u10', 'v10']: - try: - dataset = xr.open_mfdataset(self.data_root+'/{}/{}*.nc'.format( - data_map[key], data_map[key]), combine='by_coords') - except (AttributeError, ValueError): - assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \ - 'pip install xarray==0.19.0,' \ - 'pip install netcdf4 h5netcdf dask' - except OSError: - print("OSError: Invalid path {}/{}/*.nc".format(self.data_root, data_map[key])) - assert False - dataset = dataset.sel(time=slice(*self.training_time)) - dataset = dataset.isel(time=slice(None, -1, self.step)) - if self.time is None and single_variant: - self.week = dataset['time.week'] - self.month = dataset['time.month'] - self.year = dataset['time.year'] - self.time = np.stack( - [self.week, self.month, self.year], axis=1) - lon, lat = np.meshgrid( - (dataset.lon-180) * d2r, dataset.lat*d2r) - x, y, z = latlon2xyz(lat, lon) - self.V = np.stack([x, y, z]).reshape(3, self.shape[0]*self.shape[1]).T - input_datasets.append(dataset.get(key).values[:, np.newaxis, :, :]) - data = np.concatenate(input_datasets, axis=1) - - # uv10 - if len(data.shape) == 5: - data = data.squeeze(1) - # humidity - if data_name == 'r' and single_variant: - data = data[:, -1:, ...] - # multi-variant level - if not single_variant and isinstance(self.level, int): - data = data[:, -self.level:, ...] 
-
-        mean = data.mean(axis=(0, 2, 3)).reshape(1, data.shape[1], 1, 1)
-        std = data.std(axis=(0, 2, 3)).reshape(1, data.shape[1], 1, 1)
+        self.time_size = dataset.dims['time']
+
+        if 'level' not in dataset.indexes:
+            dataset = dataset.expand_dims(dim={"level": 1}, axis=1)
+        else:
+            dataset = dataset.sel(level=np.array(levels))
+
+        if data_name in data_keys_map:
+            data = dataset.get(data_keys_map[data_name]).values
+        else:
+            data = dataset.get(data_name).values
+
+        mean = data.mean().reshape(1, 1, 1, 1)
+        std = data.std().reshape(1, 1, 1, 1)
         # mean = dataset.mean('time').mean(('lat', 'lon')).compute()[data_name].values
         # std = dataset.std('time').mean(('lat', 'lon')).compute()[data_name].values
         data = (data - mean) / std
@@ -240,26 +218,27 @@ def load_data(batch_size,
               idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
               idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
               step=1,
-              level=1,
+              levels=['50'],
               distributed=False, use_augment=False,
               use_prefetcher=False, drop_last=False,
               **kwargs):
     assert data_split in ['5_625', '2_8125', '1_40625']
-    _dataroot = osp.join(data_root, f'weather_{data_split}deg')
-    weather_dataroot = _dataroot if osp.exists(_dataroot) else osp.join(data_root, 'weather')
+    for suffix in [f'weather_{data_split}deg', 'weather', f'{data_split}deg']:
+        if osp.exists(osp.join(data_root, suffix)):
+            weather_dataroot = osp.join(data_root, suffix)
+            break
 
     train_set = WeatherBenchDataset(data_root=weather_dataroot,
                                     data_name=data_name, data_split=data_split,
                                     training_time=train_time,
                                     idx_in=idx_in,
                                     idx_out=idx_out,
-                                    step=step, level=level, use_augment=use_augment)
+                                    step=step, levels=levels, use_augment=use_augment)
     vali_set = WeatherBenchDataset(weather_dataroot,
                                    data_name=data_name, data_split=data_split,
                                    training_time=val_time,
                                    idx_in=idx_in,
                                    idx_out=idx_out,
-                                   step=step, level=level, use_augment=False,
+                                   step=step, levels=levels, use_augment=False,
                                    mean=train_set.mean,
                                    std=train_set.std)
     test_set = WeatherBenchDataset(weather_dataroot,
@@ -267,7 +246,7 @@ def load_data(batch_size,
                                    data_name=data_name, data_split=data_split,
                                    training_time=test_time,
                                    idx_in=idx_in,
                                    idx_out=idx_out,
-                                   step=step, level=level, use_augment=False,
+                                   step=step, levels=levels, use_augment=False,
                                    mean=train_set.mean,
                                    std=train_set.std)
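Taken together, patch 4 turns the loader into a generic multi-variable pipeline: any mix of surface variables, pressure-level stacks, and time-invariant constants can be requested, with constant fields repeated along the time axis to match. A usage sketch of the reworked entry point (illustrative only: the data path is a placeholder, keyword names follow the signature in the hunks above, and the three-loader return value follows the usual OpenSTL convention):

    from openstl.datasets.dataloader_weather import load_data

    # 12 input steps predicting 12 output steps on 5.625-degree WeatherBench;
    # 'mv12' expands via mv_data_map to 12 variables (including the 'lsm'/'o'/'l'
    # constants), and `levels` replaces the old scalar `level` argument.
    dataloader_train, dataloader_vali, dataloader_test = load_data(
        batch_size=16,
        val_batch_size=16,
        data_root='./data',              # expects e.g. ./data/weather_5_625deg/
        data_split='5_625',
        data_name='mv12',
        levels=['500', '850'],           # strings are cast to int internally
    )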