3 changes: 3 additions & 0 deletions .gitignore
@@ -186,3 +186,6 @@ notebooks/plots/
notebooks/rollouts/

scripts/run_single_*.txt

.DS_Store
scripts/.DS_Store
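
Note: an unanchored gitignore pattern such as `.DS_Store` matches at any directory depth, so the `scripts/.DS_Store` entry is already covered by the line above it.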
2 changes: 1 addition & 1 deletion README.md
@@ -36,7 +36,7 @@ If you need a package that is not automatically installed, please run `poetry ad

### 3. Downloading input data

For running the model on real cliamte data, please download monthly climate model data and regrid it to an icosahedral grid using ClimateSet https://github.com/RolnickLab/ClimateSet.
For running the model on real climate data, please download monthly climate model data and regrid it to an icosahedral grid using ClimateSet https://github.com/RolnickLab/ClimateSet.
If you have any problem downloading or formatting the data, please get in touch.

### 4. Running the model
18 changes: 13 additions & 5 deletions climatem/__init__.py
@@ -1,10 +1,18 @@
from pathlib import Path
import sys

APP_ROOT = Path(__file__).resolve().parent
PROJECT_ROOT = APP_ROOT.parent
DATA_DIR = PROJECT_ROOT / "data"
DATA_DIR = PROJECT_ROOT / "data" # doesn't exist
SCRIPTS_DIR = PROJECT_ROOT / "scripts"
CONFIGS_PATH = SCRIPTS_DIR / "configs"
PARAMS_PATH = SCRIPTS_DIR / "params"
MAPPINGS_DIR = PROJECT_ROOT / "climatem" / "mappings"
TUNING_CONFIGS = CONFIGS_PATH / "tuning"
CONFIGS_PATH = PROJECT_ROOT / "configs"
PARAMS_PATH = PROJECT_ROOT / "params" # doesn't exist
MAPPINGS_DIR = PROJECT_ROOT / "mappings"
TUNING_CONFIGS = CONFIGS_PATH / "tuning" # doesn't exist

if sys.platform == "linux":
    SCRATCH_DIR = Path.home() / "scratch"
    print("Climatem: Detected Linux system, using scratch directory: ", SCRATCH_DIR)
else:
    SCRATCH_DIR = PROJECT_ROOT.parent.parent / "scratch"  # hardcoded for local machine
    print("Climatem: Detected non-Linux system, using scratch directory: ", SCRATCH_DIR)
2 changes: 2 additions & 0 deletions climatem/config.py
@@ -162,6 +162,7 @@ def __init__(
fixed_output_fraction=None, # NOT SURE, Remove this?
tau_neigh: int = 0, # NOT SURE
hard_gumbel: bool = False, # NOT SURE
use_grad_norm: bool = False,
):
self.instantaneous = instantaneous
self.no_w_constraint = no_w_constraint
@@ -179,6 +180,7 @@ def __init__(
self.fixed_output_fraction = fixed_output_fraction
self.tau_neigh = tau_neigh
self.hard_gumbel = hard_gumbel
self.use_grad_norm = use_grad_norm


class optimParams:
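`use_grad_norm` only enters the config here; the code that consumes it is not in this diff, and whether it toggles gradient-norm clipping or a GradNorm-style loss weighting is not visible. A minimal sketch assuming the simpler reading, gradient-norm clipping (`max_norm` and the surrounding names are assumptions):

    import torch

    def training_step(model, loss, optimizer, use_grad_norm: bool, max_norm: float = 1.0):
        optimizer.zero_grad()
        loss.backward()
        if use_grad_norm:
            # Rescale gradients so their global L2 norm is at most max_norm.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()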
28 changes: 17 additions & 11 deletions climatem/data_loader/causal_datamodule.py
@@ -35,14 +35,17 @@ class CausalClimateDataModule(ClimateDataModule):
The setup method is overwritten and performs data preprocessing for causal discovery models.
"""

def __init__(self, tau=5, future_timesteps=1, num_months_aggregated=1, train_val_interval_length=100, **kwargs):
def __init__(self, tau=5, future_timesteps=1, num_months_aggregated=1, train_val_interval_length=100, spatial_resolution=10, num_modes=4, dimensions=1600, **kwargs):
super().__init__(self)

# kwargs are initialized as self.hparams by the Lightning module
# What is this line? We cannot have different test vs train models
# self.hparams.test_models = None if self.hparams.test_models else self.hparams.train_models
self.spatial_resolution = spatial_resolution
self.dimensions = dimensions
self.hparams.test_models = self.hparams.train_models
self.tau = tau
self.num_modes = num_modes
self.future_timesteps = future_timesteps
self.num_months_aggregated = num_months_aggregated
self.train_val_interval_length = train_val_interval_length
@@ -82,7 +85,7 @@ def setup(self, stage: Optional[str] = None):
# TODO: propagate "reload argument here"
# TODO: make sure all arguments are propagated i.e. seasonality_removal, output_save_dir
if "savar" in self.hparams.in_var_ids:
train_val_input4mips = SavarDataset(
self.train_val_input4mips = SavarDataset(
# Make sure these arguments are propagated
output_save_dir=self.hparams.output_save_dir,
lat=self.hparams.lat,
@@ -100,6 +103,7 @@
overlap=self.hparams.overlap,
is_forced=self.hparams.is_forced,
plot_original_data=self.hparams.plot_original_data,
seed=self.hparams.seed,
)
elif (
"tas" in self.hparams.in_var_ids
@@ -111,7 +115,7 @@
print(
f"Causal datamodule self.hparams.icosahedral_coordinates_path {self.hparams.icosahedral_coordinates_path}"
)
train_val_input4mips = CMIP6Dataset(
self.train_val_input4mips = CMIP6Dataset(
years=train_years,
historical_years=train_historical_years,
data_dir=self.hparams.data_dir,
@@ -131,7 +135,7 @@
reload_climate_set_data=self.hparams.reload_climate_set_data,
)
elif "t2m" in self.hparams.in_var_ids:
train_val_input4mips = ERA5Dataset(
self.train_val_input4mips = ERA5Dataset(
years=train_years,
historical_years=train_historical_years,
data_dir=self.hparams.data_dir,
@@ -151,7 +155,7 @@
reload_climate_set_data=self.hparams.reload_climate_set_data,
)
else:
train_val_input4mips = Input4MipsDataset(
self.train_val_input4mips = Input4MipsDataset(
years=train_years,
historical_years=train_historical_years,
data_dir=self.hparams.data_dir,
@@ -171,7 +175,7 @@

ratio_train = 1 - self.hparams.val_split

train, val = train_val_input4mips.get_causal_data(
train, val = self.train_val_input4mips.get_causal_data(
tau=self.tau,
future_timesteps=self.future_timesteps,
channels_last=self.hparams.channels_last,
@@ -185,9 +189,11 @@
mode="train+val",
)
if "savar" in self.hparams.in_var_ids:
self.savar_gt_modes = train_val_input4mips.gt_modes
self.savar_gt_noise = train_val_input4mips.gt_noise
self.savar_gt_adj = train_val_input4mips.gt_adj
self.savar_gt_modes_weights = self.train_val_input4mips.gt_modes_weights
self.savar_gt_modes = self.train_val_input4mips.gt_modes
self.savar_gt_noise = self.train_val_input4mips.gt_noise
self.savar_gt_adj = self.train_val_input4mips.gt_adj
self.savar_links_coeffs = self.train_val_input4mips.links_coeffs

train_x, train_y = train
train_x = train_x.reshape((train_x.shape[0], train_x.shape[1], train_x.shape[2], -1))
@@ -203,7 +209,7 @@
val_y = val_y.reshape((val_y.shape[0], val_y.shape[1], val_y.shape[2], -1))
self._data_val = CausalDataset(val_x, val_y)

self.coordinates = train_val_input4mips.coordinates
self.coordinates = self.train_val_input4mips.coordinates

if stage in ["test", None]:
openburning_specs = {
@@ -213,4 +219,4 @@
else OPENBURNING_MODEL_MAPPING["other"]
)
for test_model in self.hparams.test_models
}
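
Promoting the dataset to `self.train_val_input4mips` is what makes the SAVAR ground-truth attributes readable after `setup()` returns. A usage sketch under that assumption (`dm` is a configured `CausalClimateDataModule` for a SAVAR run; the evaluation code itself is not part of this diff):

    dm.setup(stage="fit")
    gt_adj = dm.savar_gt_adj                # ground-truth adjacency across lags
    gt_weights = dm.savar_gt_modes_weights  # per-mode spatial weight maps
    links = dm.savar_links_coeffs           # raw links/coefficients used to generate the data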
36 changes: 27 additions & 9 deletions climatem/data_loader/climate_datamodule.py
@@ -171,8 +171,14 @@ def setup(self, stage: Optional[str] = None):
fractions = [1 - self.hparams.val_split, self.hparams.val_split]
ds_list = random_split(full_ds, lengths=fractions)
train_ds, val_ds = ds_list

train_ds.Data = train_ds.Data.to(self.accelerator.device)
print ("train_ds device: ", train_ds.Data.device)
val_ds.Data = val_ds.Data.to(self.accelerator.device)

self._data_train = train_ds
self._data_val = val_ds

# Test sets:
if stage == "test" or stage is None:
self._data_test = [
@@ -219,27 +225,39 @@ def train_dataloader(self, accelerator):
# # setup_ddp()
# train_sampler = DistributedSampler(dataset=self._data_train, shuffle=True)

return DataLoader(
# Set generator seed for reproducibility
generator = torch.Generator(device=accelerator.device)
# generator = torch.Generator()
generator.manual_seed(self.hparams.seed)
# print("self._data_train device: ", self._data_train.Data.device)
# self._data_train.x = torch.from_numpy(self._data_train.x).to(accelerator.device)
# self._data_train.y = torch.from_numpy(self._data_train.y).to(accelerator.device)
# breakpoint()
dl = DataLoader(
dataset=self._data_train,
batch_size=self.hparams.batch_size,
shuffle=True,
generator=torch.Generator(device=accelerator.device),
# generator=torch.Generator(device=accelerator.device),
generator=generator,
drop_last=True,
**self._shared_dataloader_kwargs(),
)
print ("dl construction worked")
dl = accelerator.prepare(dl)
return dl

def val_dataloader(self):
def val_dataloader(self, accelerator):

# valid_sampler = None
# if multi_gpu:
# # setup_ddp()
# valid_sampler = DistributedSampler(dataset=self._data_val, shuffle=False)
dl = DataLoader(dataset=self._data_val, drop_last=True, **self._shared_eval_dataloader_kwargs())
# self._data_val.x = torch.from_numpy(self._data_val.x).to(accelerator.device)
# self._data_val.y = torch.from_numpy(self._data_val.y).to(accelerator.device)

return (
DataLoader(dataset=self._data_val, drop_last=True, **self._shared_eval_dataloader_kwargs())
if self._data_val is not None
else None
)
dl = accelerator.prepare(dl)
return dl

def test_dataloader(self) -> List[DataLoader]:

@@ -263,4 +281,4 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
if self._data_val is not None
else None
)
]
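
The train loader now draws its shuffle order from a `torch.Generator` seeded with `self.hparams.seed`, so runs with the same seed see batches in the same order. A self-contained illustration of the pattern (CPU generator for portability; the data here is toy):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(10))
    g = torch.Generator()
    g.manual_seed(42)
    dl = DataLoader(ds, batch_size=2, shuffle=True, generator=g, drop_last=True)
    print([batch[0].tolist() for batch in dl])  # same batch order on every run with seed 42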
6 changes: 5 additions & 1 deletion climatem/data_loader/climate_dataset.py
@@ -21,6 +21,8 @@

log = get_logger()

device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")


# base data set: implements copy to slurm, get item etc pp
# cmip6 data set: model wise
@@ -73,6 +75,7 @@ def __init__(
"""

super().__init__()
self.Data = None # used by inherited classes to store the data (e.g. input4mips)
self.test_dir = output_save_dir
self.output_save_dir = Path(output_save_dir)
self.reload_climate_set_data = reload_climate_set_data
@@ -331,7 +334,8 @@ def get_causal_data(
num_years = self.length
# print("In get_causal_data, num_years:", num_years)

data = self.Data
data = self.Data.to(device)
print("data device: ", data.device)

# print("Here in get_causal_data, self.length:", self.length)

38 changes: 30 additions & 8 deletions climatem/data_loader/savar_dataset.py
@@ -28,10 +28,12 @@ def __init__(
overlap: bool = False,
is_forced: bool = False,
plot_original_data: bool = True,
seed: int = 42,
):
super().__init__()
self.output_save_dir = Path(output_save_dir)
self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}"
# self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}"
self.savar_name = f"modes_{n_per_col**2}-diff_{difficulty}-seed_{seed}"
self.savar_path = self.output_save_dir / f"{self.savar_name}.npy"

self.global_normalization = global_normalization
@@ -41,6 +43,7 @@
# TODO: for now this is ok, we create a square grid. Later we might want to look at icosahedral grid :)
self.lat = lat
self.lon = lon
self.tau = tau
self.coordinates = np.array(np.meshgrid(np.arange(self.lat), np.arange(self.lon))).reshape((2, -1)).T

self.time_len = time_len
@@ -56,15 +59,17 @@
if self.reload_climate_set_data:
self.gt_modes = np.load(self.output_save_dir / f"{self.savar_name}_modes.npy")
self.gt_noise = np.load(self.output_save_dir / f"{self.savar_name}_noise_modes.npy")
links_coeffs = np.load(
self.links_coeffs = np.load(
self.output_save_dir / f"{self.savar_name}_parameters.npy", allow_pickle=True
).item()["links_coeffs"]
self.gt_adj = np.array(extract_adjacency_matrix(links_coeffs, n_per_col**2, tau))[::-1]
self.gt_adj = np.array(extract_adjacency_matrix(self.links_coeffs, self.n_per_col**2, self.tau))
self.gt_modes_weights = np.load(self.output_save_dir / f"{self.savar_name}_mode_weights.npy")
else:
self.gt_modes = None
self.gt_noise = None
links_coeffs = None
self.links_coeffs = None
self.gt_adj = None
self.gt_modes_weights = None

@staticmethod
def aggregate_months(data, num_months_aggregated):
@@ -174,16 +179,18 @@ def get_causal_data(
self.overlap,
self.is_forced,
self.plot_original_data,
tau,
)
time_steps = data.shape[1]
data = data.T.reshape((time_steps, self.lat, self.lon))

self.gt_modes_weights = np.load(self.output_save_dir / f"{self.savar_name}_mode_weights.npy")
self.gt_modes = np.load(self.output_save_dir / f"{self.savar_name}_modes.npy")
self.gt_noise = np.load(self.output_save_dir / f"{self.savar_name}_noise_modes.npy")
links_coeffs = np.load(
self.links_coeffs = np.load(
self.output_save_dir / f"{self.savar_name}_parameters.npy", allow_pickle=True
).item()["links_coeffs"]
self.gt_adj = np.array(extract_adjacency_matrix(links_coeffs, self.n_per_col**2, tau))
self.gt_adj = np.array(extract_adjacency_matrix(self.links_coeffs, self.n_per_col**2, tau))

data = data.astype("float32")
# TODO: normalize by saving std/mean from train data and then normalize test by reloading
@@ -294,11 +301,26 @@ def get_causal_data(
x_valid, y_valid = self.get_overlapping_sequences(scenario, idx_valid, tau, future_timesteps)
x_valid_list.extend(x_valid)
y_valid_list.extend(y_valid)
train_x, train_y = np.stack(x_train_list), np.stack(y_train_list)
future_timesteps = y_train_list[0].shape[0]

# Because of the sliding window, the trailing y_train entries can hold fewer
# than `future_timesteps` steps (only possible when future_timesteps > 1), so
# search backwards for the last entry with the full horizon and truncate there.
idx = len(y_train_list)
for i in range(len(y_train_list) - 1, -1, -1):
    if y_train_list[i].shape[0] == future_timesteps:
        idx = i + 1  # keep windows 0..i inclusive
        break

train_x, train_y = np.stack(x_train_list[:idx]), np.stack(y_train_list[:idx])
if ratio_train == 1:
valid_x, valid_y = np.array(x_valid_list), np.array(y_valid_list)
else:
valid_x, valid_y = np.stack(x_valid_list), np.stack(y_valid_list)
future_timesteps = y_valid_list[0].shape[0]
valid_x, valid_y = np.stack(x_valid_list[:-future_timesteps]), np.stack(
y_valid_list[:-future_timesteps]
)
train_y = np.expand_dims(train_y, axis=1)
valid_y = np.expand_dims(valid_y, axis=1)
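
The truncation added in `get_causal_data` above guards the `future_timesteps > 1` case, where the last sliding windows carry targets shorter than the requested horizon. A toy illustration of the same idea (self-contained, not repository code):

    import numpy as np

    data = np.arange(10)
    tau, future_timesteps = 3, 2
    xs = [data[i : i + tau] for i in range(len(data) - tau)]
    ys = [data[i + tau : i + tau + future_timesteps] for i in range(len(data) - tau)]
    # The final target is short: ys[-1] == array([9]), length 1.
    idx = max(i for i, y in enumerate(ys) if y.shape[0] == future_timesteps) + 1
    xs, ys = xs[:idx], ys[:idx]  # drop short-horizon windows so stacking is safe
    print(np.stack(ys).shape)    # (6, 2)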
