Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions climatem/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,20 @@ def __init__(
n_per_col: int = 2, # square grid, equivalent of lat/lon
difficulty: str = "easy", # easy, med_easy, med_hard, hard: difficulty of the graph
seasonality: bool = False, # Seasonality in synthetic data
periods: List[float] = [365, 182.5, 60], # Periods of the seasonality in days
amplitudes: List[float] = [0.06, 0.02, 0.01], # Amplitudes of the seasonality
phases: List[float] = [0.0, 0.7853981634, 1.5707963268], # Phases of the seasonality in radians
yearly_jitter_amp: float = 0.05, # Amplitude of the yearly jitter
yearly_jitter_phase: float = 0.10, # Phase of the yearly
overlap: bool = False, # Modes overlap
is_forced: bool = False, # Forcings in synthetic data
f_1: int = 1,
f_2: int = 2,
f_time_1: int = 4000,
f_time_2: int = 8000,
ramp_type: str = "linear",
linearity: str = "linear",
poly_degrees: List[int] = [2],
plot_original_data: bool = True,
):
self.time_len = time_len
Expand All @@ -298,8 +310,20 @@ def __init__(
self.n_per_col = n_per_col
self.difficulty = difficulty
self.seasonality = seasonality
self.periods = periods
self.amplitudes = amplitudes
self.phases = phases
self.yearly_jitter_amp = yearly_jitter_amp
self.yearly_jitter_phase = yearly_jitter_phase
self.overlap = overlap
self.is_forced = is_forced
self.f_1 = f_1
self.f_2 = f_2
self.f_time_1 = f_time_1
self.f_time_2 = f_time_2
self.ramp_type = ramp_type
self.linearity = linearity
self.poly_degrees = poly_degrees
self.plot_original_data = plot_original_data


Expand Down
11 changes: 9 additions & 2 deletions climatem/data_loader/causal_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
# import relevant data loading modules
from climatem.data_loader.climate_datamodule import ClimateDataModule
from climatem.data_loader.cmip6_dataset import CMIP6Dataset
from climatem.data_loader.input4mip_dataset import Input4MipsDataset
from climatem.data_loader.era5_dataset import ERA5Dataset
from climatem.data_loader.input4mip_dataset import Input4MipsDataset
from climatem.data_loader.savar_dataset import SavarDataset


Expand Down Expand Up @@ -99,6 +99,13 @@ def setup(self, stage: Optional[str] = None):
seasonality=self.hparams.seasonality,
overlap=self.hparams.overlap,
is_forced=self.hparams.is_forced,
f_1=self.hparams.f_1,
f_2=self.hparams.f_2,
f_time_1=self.hparams.f_time_1,
f_time_2=self.hparams.f_time_2,
ramp_type=self.hparams.ramp_type,
linearity=self.hparams.linearity,
poly_degrees=self.hparams.poly_degrees,
plot_original_data=self.hparams.plot_original_data,
)
elif (
Expand Down Expand Up @@ -213,4 +220,4 @@ def setup(self, stage: Optional[str] = None):
else OPENBURNING_MODEL_MAPPING["other"]
)
for test_model in self.hparams.test_models
}
}
4 changes: 3 additions & 1 deletion climatem/data_loader/climate_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def __init__(
seasonality: bool = False,
overlap: bool = False,
is_forced: bool = False,
linearity: str = "linear",
poly_degrees: List[int] = [2],
plot_original_data: bool = True,
):
"""
Expand Down Expand Up @@ -263,4 +265,4 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
if self._data_val is not None
else None
)
]
]
1 change: 1 addition & 0 deletions climatem/data_loader/cmip6_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# from climatem.plotting.plot_data import plot_species, plot_species_anomaly
from climatem.utils import get_logger

from .climate_dataset import ClimateDataset

log = get_logger()
Expand Down
3 changes: 1 addition & 2 deletions climatem/data_loader/input4mip_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# NOTE: as of 14th Oct, I am also trying to get this to work for multiple variables.

import glob
import os
import zipfile
from pathlib import Path
from typing import List, Optional, Tuple, Union

Expand All @@ -24,6 +22,7 @@
# input4mips data set: same per model
# from datamodule create one of these per train/test/val


class Input4MipsDataset(ClimateDataset):
"""
Loads all scenarios for a given var / for all vars.
Expand Down
124 changes: 96 additions & 28 deletions climatem/data_loader/savar_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Optional
from typing import List, Optional, Sequence

import numpy as np
import torch
Expand All @@ -25,13 +25,25 @@ def __init__(
n_per_col: int = 2,
difficulty: str = "easy",
seasonality: bool = False,
periods: List[float] = [365, 182.5, 60],
amplitudes: List[float] = [0.06, 0.02, 0.01],
phases: List[float] = [0.0, 0.7853981634, 1.5707963268],
yearly_jitter_amp: float = 0.05,
yearly_jitter_phase: float = 0.10,
overlap: bool = False,
is_forced: bool = False,
f_1: int = 1,
f_2: int = 2,
f_time_1: int = 4000,
f_time_2: int = 8000,
ramp_type: str = "linear",
linearity: str = "linear",
poly_degrees: List[int] = [2, 3],
plot_original_data: bool = True,
):
super().__init__()
self.output_save_dir = Path(output_save_dir)
self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}"
self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}_f1_{f_1}_f2_{f_2}_ft1_{f_time_1}_ft2_{f_time_2}_ramp_{ramp_type}_linearity_{linearity}_polydegs_{poly_degrees}"
self.savar_path = self.output_save_dir / f"{self.savar_name}.npy"

self.global_normalization = global_normalization
Expand All @@ -49,8 +61,20 @@ def __init__(
self.n_per_col = n_per_col
self.difficulty = difficulty
self.seasonality = seasonality
self.periods = periods
self.amplitudes = amplitudes
self.phases = phases
self.yearly_jitter_amp = yearly_jitter_amp
self.yearly_jitter_phase = yearly_jitter_phase
self.overlap = overlap
self.is_forced = is_forced
self.f_1 = f_1
self.f_2 = f_2
self.f_time_1 = f_time_1
self.f_time_2 = f_time_2
self.ramp_type = ramp_type
self.linearity = linearity
self.poly_degrees = poly_degrees
self.plot_original_data = plot_original_data

if self.reload_climate_set_data:
Expand Down Expand Up @@ -171,8 +195,20 @@ def get_causal_data(
self.n_per_col,
self.difficulty,
self.seasonality,
self.periods,
self.amplitudes,
self.phases,
self.yearly_jitter_amp,
self.yearly_jitter_phase,
self.overlap,
self.is_forced,
self.f_1,
self.f_2,
self.f_time_1,
self.f_time_2,
self.ramp_type,
self.linearity,
self.poly_degrees,
self.plot_original_data,
)
time_steps = data.shape[1]
Expand All @@ -191,7 +227,14 @@ def get_causal_data(
if self.global_normalization:
data = (data - data.mean()) / data.std()
if self.seasonality_removal:
self.norm_data = self.remove_seasonality(self.norm_data)
data = self.remove_seasonality(
data,
periods=self.periods, # already a constructor arg (e.g. 12)
demean=True,
normalise=False,
rolling=True,
w=10, # 10 years ≈ 120 steps @ monthly
)

print(f"data is {data.dtype}")

Expand Down Expand Up @@ -371,35 +414,60 @@ def get_min_max(self, data):

return vars_min, vars_max

# important?
# NOTE:(seb) I need to check the axis is correct here?
def remove_seasonality(self, data):
def remove_seasonality(
self,
data: np.ndarray,
periods: int | Sequence[int] | Sequence[float] = (12, 6, 3),
demean: bool = True,
normalise: bool = False,
rolling: bool = True, # ← default TRUE because of jitter
w: int = 10, # (10 years ≈ 120 steps @ monthly)
):
"""
Function to remove seasonality from the data There are various different options to do this These are just
different methods of removing seasonality.
Remove deterministic periodic seasonality from a [time, …] array.

e.g.
monthly - remove seasonality on a per month basis
rolling monthly - remove seasonality on a per month basis but using a rolling window,
removing only the average from the months that have preceded this month
linear - remove seasonality using a linear model to predict seasonality

or trend removal
emissions - remove the trend using the emissions data, such as cumulative CO2
Parameters
----------
period single cycle length **or** list/tuple of lengths
(e.g. [12, 6] for annual + semi-annual)
"""

mean = np.nanmean(data, axis=0)
std = np.nanstd(data, axis=0)

# return data

# NOTE: SH - do we not do this above?
# standardise - I hope this is doing by month, to check

return (data - mean[None]) / std[None]

# now just divide by std...
# return data / std[None]
def _remove_one(x: np.ndarray, p: int) -> np.ndarray:
"""Inner helper that handles a single period length."""
t = x.shape[0]
rem = t % p
if rem:
x = x[:-rem]
t -= rem
folded = x.reshape((t // p, p) + x.shape[1:])
if rolling:
k = min(w, folded.shape[0])
mean = np.nanmean(folded[-k:], axis=0)
std = np.nanstd(folded[-k:], axis=0)
else:
mean = np.nanmean(folded, axis=0)
std = np.nanstd(folded, axis=0)
mean_full = np.tile(mean, (t // p, *[1] * (x.ndim - 1)))
std_full = np.tile(std, (t // p, *[1] * (x.ndim - 1)))
out = x.copy()
if demean:
out -= mean_full
if normalise:
out /= np.where(std_full == 0, 1, std_full)
return out.astype(np.float32)

# handle one or many cycle lengths
if isinstance(periods, (list, tuple, np.ndarray)):
# remove the longest cycle first to avoid leakage
_periods = sorted([int(round(p)) for p in periods], reverse=True)
else: # single scalar
_periods = [int(round(periods))]

out = data.astype(np.float32)
for p in _periods:
out = _remove_one(out, p)
return out

def write_dataset_statistics(self, fname, stats):
# fname = fname.replace('.npz.npy', '.npy')
Expand Down
4 changes: 4 additions & 0 deletions climatem/model/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ def __init__(
self,
model,
datamodule,
data_params,
exp_params,
gt_params,
model_params,
train_params,
optim_params,
plot_params,
savar_params,
save_path,
plots_path,
best_metrics,
Expand All @@ -45,6 +47,7 @@ def __init__(
self.data_loader_train = iter(datamodule.train_dataloader(accelerator=accelerator))
self.data_loader_val = iter(datamodule.val_dataloader())
self.coordinates = datamodule.coordinates
self.data_params = data_params
self.exp_params = exp_params
self.train_params = train_params
self.optim_params = optim_params
Expand All @@ -55,6 +58,7 @@ def __init__(
)

self.plot_params = plot_params
self.savar_params = savar_params
self.best_metrics = best_metrics
self.save_path = save_path
self.plots_path = plots_path
Expand Down
Loading