Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve weather dataloader #136

Open
wants to merge 1 commit into
base: OpenSTL-Lightning
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 79 additions & 100 deletions openstl/datasets/dataloader_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn.functional as F
from torch.utils.data import Dataset
from openstl.datasets.utils import create_loader
import tqdm

try:
import xarray as xr
Expand Down Expand Up @@ -42,23 +43,33 @@ def xyz2latlon(x, y, z):


data_map = {
'g': 'geopotential',
'z': 'geopotential_500',
'z': 'geopotential',
't': 'temperature',
't850': 'temperature_850',
'tp': 'total_precipitation',
't2m': '2m_temperature',
'r': 'relative_humidity',
's': 'specific_humidity',
'u10': '10m_u_component_of_wind',
'u': 'u_component_of_wind',
'v10': '10m_v_component_of_wind',
'v': 'v_component_of_wind',
'tcc': 'total_cloud_cover',
"lsm": "constants",
"o": "constants",
"l": "constants",
}

mv_data_map = {
**dict.fromkeys(['mv', 'mv4'], ['r', 't', 'u', 'v']),
'mv5': ['g', 'r', 't', 'u', 'v'],
'mv5': ['z', 'r', 't', 'u', 'v'],
'uv10': ['u10', 'v10'],
'mv12': ['lsm', 'o', 't2m', 'u10', 'v10', 'l', 'z', 'u', 'v', 't', 'r', 's']
}

data_keys_map = {
'o': 'orography',
'l': 'lat2d',
's': 'q'
}


Expand All @@ -67,132 +78,99 @@ class WeatherBenchDataset(Dataset):

Args:
data_root (str): Path to the dataset.
data_name (str): Name of the weather modality in Wheather Bench.
data_name (str|list): Name(s) of the weather modality in Wheather Bench.
training_time (list): The arrange of years for training.
idx_in (list): The list of input indices.
idx_out (list): The list of output indices to predict.
step (int): Sampling step in the time dimension.
level (int): Used level in the multi-variant version.
level (int|list|"all"): Level(s) to use.
data_split (str): The resolution (degree) of Wheather Bench splits.
use_augment (bool): Whether to use augmentations (defaults to False).
"""

def __init__(self, data_root, data_name, training_time,
idx_in, idx_out, step=1, level=1, data_split='5_625',
idx_in, idx_out, step=1, levels=['50'], data_split='5_625',
mean=None, std=None,
transform_data=None, transform_labels=None, use_augment=False):
super().__init__()
self.data_root = data_root
self.data_name = data_name
self.data_split = data_split
self.training_time = training_time
self.idx_in = np.array(idx_in)
self.idx_out = np.array(idx_out)
self.step = step
self.level = level
self.data = None
self.mean = mean
self.std = std
self.transform_data = transform_data
self.transform_labels = transform_labels
self.use_augment = use_augment
assert isinstance(level, (int, list))

self.time = None
self.time_size = self.training_time
shape = int(32 * 5.625 / float(data_split.replace('_', '.')))
self.shape = (shape, shape * 2)

if isinstance(data_name, list):
data_name = data_name[0]
if 'mv' in data_name: # multi-variant version
self.data_name = mv_data_map[data_name]
self.data, self.mean, self.std = [], [], []
for name in self.data_name:
data, mean, std = self._load_data_xarray(data_name=name, single_variant=False)
self.data.append(data)
self.mean.append(mean)
self.std.append(std)
self.data = np.concatenate(self.data, axis=1)
self.mean = np.concatenate(self.mean, axis=1)
self.std = np.concatenate(self.std, axis=1)
else: # single variant
self.data_name = data_name
self.data, mean, std = self._load_data_xarray(data_name, single_variant=True)
if self.mean is None:
self.mean, self.std = mean, std
self.data, self.mean, self.std = [], [], []

if levels == 'all':
levels = ['50', '250', '500', '600', '700', '850', '925']
levels = levels if isinstance(levels, list) else [levels]
levels = [int(level) for level in levels]
if isinstance(data_name, str) and data_name in mv_data_map:
data_names = mv_data_map[data_name]
else:
data_names = data_name if isinstance(data_name, list) else [data_name]

for name in tqdm.tqdm(data_names):
data, mean, std = self._load_data_xarray(data_name=name, levels=levels)
self.data.append(data)
self.mean.append(mean)
self.std.append(std)

for i, data in enumerate(self.data):
if data.shape[0] != self.time_size:
self.data[i] = data.repeat(self.time_size, axis=0)

self.data = np.concatenate(self.data, axis=1)
self.mean = np.concatenate(self.mean, axis=1)
self.std = np.concatenate(self.std, axis=1)

self.valid_idx = np.array(
range(-idx_in[0], self.data.shape[0]-idx_out[-1]-1))

def _load_data_xarray(self, data_name, single_variant=True):
def _load_data_xarray(self, data_name, levels):
"""Loading full data with xarray"""
if data_name != 'uv10':
try:
dataset = xr.open_mfdataset(self.data_root+'/{}/{}*.nc'.format(
data_map[data_name], data_map[data_name]), combine='by_coords')
except (AttributeError, ValueError):
assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \
'pip install xarray==0.19.0,' \
'pip install netcdf4 h5netcdf dask'
except OSError:
print("OSError: Invalid path {}/{}/*.nc".format(self.data_root, data_map[data_name]))
assert False
try:
dataset = xr.open_mfdataset(self.data_root+'/{}/{}*.nc'.format(
data_map[data_name], data_map[data_name]), combine='by_coords')
except (AttributeError, ValueError):
assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \
'pip install xarray==0.19.0,' \
'pip install netcdf4 h5netcdf dask'
except OSError:
print("OSError: Invalid path {}/{}/*.nc".format(self.data_root, data_map[data_name]))
assert False

if 'time' not in dataset.indexes:
dataset = dataset.expand_dims(dim={"time": 1}, axis=0)
else:
dataset = dataset.sel(time=slice(*self.training_time))
dataset = dataset.isel(time=slice(None, -1, self.step))
if self.time is None and single_variant:
self.week = dataset['time.week']
self.month = dataset['time.month']
self.year = dataset['time.year']
self.time = np.stack(
[self.week, self.month, self.year], axis=1)
lon, lat = np.meshgrid(
(dataset.lon-180) * d2r, dataset.lat*d2r)
x, y, z = latlon2xyz(lat, lon)
self.V = np.stack([x, y, z]).reshape(3, self.shape[0]*self.shape[1]).T
if not single_variant and isinstance(self.level, list):
dataset = dataset.sel(level=np.array(self.level))
data = dataset.get(data_name).values[:, np.newaxis, :, :]

elif data_name == 'uv10':
input_datasets = []
for key in ['u10', 'v10']:
try:
dataset = xr.open_mfdataset(self.data_root+'/{}/{}*.nc'.format(
data_map[key], data_map[key]), combine='by_coords')
except (AttributeError, ValueError):
assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \
'pip install xarray==0.19.0,' \
'pip install netcdf4 h5netcdf dask'
except OSError:
print("OSError: Invalid path {}/{}/*.nc".format(self.data_root, data_map[key]))
assert False
dataset = dataset.sel(time=slice(*self.training_time))
dataset = dataset.isel(time=slice(None, -1, self.step))
if self.time is None and single_variant:
self.week = dataset['time.week']
self.month = dataset['time.month']
self.year = dataset['time.year']
self.time = np.stack(
[self.week, self.month, self.year], axis=1)
lon, lat = np.meshgrid(
(dataset.lon-180) * d2r, dataset.lat*d2r)
x, y, z = latlon2xyz(lat, lon)
self.V = np.stack([x, y, z]).reshape(3, self.shape[0]*self.shape[1]).T
input_datasets.append(dataset.get(key).values[:, np.newaxis, :, :])
data = np.concatenate(input_datasets, axis=1)

# uv10
if len(data.shape) == 5:
data = data.squeeze(1)
# humidity
if data_name == 'r' and single_variant:
data = data[:, -1:, ...]
# multi-variant level
if not single_variant and isinstance(self.level, int):
data = data[:, -self.level:, ...]

mean = data.mean(axis=(0, 2, 3)).reshape(1, data.shape[1], 1, 1)
std = data.std(axis=(0, 2, 3)).reshape(1, data.shape[1], 1, 1)
self.time_size = dataset.dims['time']

if 'level' not in dataset.indexes:
dataset = dataset.expand_dims(dim={"level": 1}, axis=1)
else:
dataset = dataset.sel(level=np.array(levels))

if data_name in data_keys_map:
data = dataset.get(data_keys_map[data_name]).values
else:
data = dataset.get(data_name).values

mean = data.mean().reshape(1, 1, 1, 1)
std = data.std().reshape(1, 1, 1, 1)
# mean = dataset.mean('time').mean(('lat', 'lon')).compute()[data_name].values
# std = dataset.std('time').mean(('lat', 'lon')).compute()[data_name].values
data = (data - mean) / std
Expand Down Expand Up @@ -240,34 +218,35 @@ def load_data(batch_size,
idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
step=1,
level=1,
levels=['50'],
distributed=False, use_augment=False, use_prefetcher=False, drop_last=False,
**kwargs):

assert data_split in ['5_625', '2_8125', '1_40625']
_dataroot = osp.join(data_root, f'weather_{data_split}deg')
weather_dataroot = _dataroot if osp.exists(_dataroot) else osp.join(data_root, 'weather')
for suffix in [f'weather_{data_split}deg', f'weather', f'{data_split}deg']:
if osp.exists(osp.join(data_root, suffix)):
weather_dataroot = osp.join(data_root, suffix)

train_set = WeatherBenchDataset(data_root=weather_dataroot,
data_name=data_name, data_split=data_split,
training_time=train_time,
idx_in=idx_in,
idx_out=idx_out,
step=step, level=level, use_augment=use_augment)
step=step, levels=levels, use_augment=use_augment)
vali_set = WeatherBenchDataset(weather_dataroot,
data_name=data_name, data_split=data_split,
training_time=val_time,
idx_in=idx_in,
idx_out=idx_out,
step=step, level=level, use_augment=False,
step=step, levels=levels, use_augment=False,
mean=train_set.mean,
std=train_set.std)
test_set = WeatherBenchDataset(weather_dataroot,
data_name, data_split=data_split,
training_time=test_time,
idx_in=idx_in,
idx_out=idx_out,
step=step, level=level, use_augment=False,
step=step, levels=levels, use_augment=False,
mean=train_set.mean,
std=train_set.std)

Expand Down