Skip to content

Commit

Permalink
Merge pull request #211 from UrbanSystemsLab/params-dataclass
Browse files (browse the repository at this point in the history)
AtmoModel.Params dataclass
  • Loading branch information
Katsutoshii authored Feb 24, 2025
2 parents 050060f + 935cb8c commit 10108ea
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 2,769 deletions.
2,740 changes: 136 additions & 2,604 deletions usl_models/notebooks/train_atmo_model.ipynb

Large diffs are not rendered by default.

34 changes: 15 additions & 19 deletions usl_models/tests/atmo_ml/atmo_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,14 @@

def pytest_model_params() -> atmo_model.AtmoModel.Params:
"""Defines AtmoModel.Params for testing."""
params = atmo_model.AtmoModel.default_params()
params.update(
{
"batch_size": 4,
"lstm_units": 32,
"lstm_kernel_size": 3,
# Use faster optimizer setting for early stopping.
"optimizer_config": keras.optimizers.Adam(
learning_rate=1e-3,
global_clipnorm=0.1,
),
}
return atmo_model.AtmoModel.Params(
output_timesteps=2,
lstm_units=32,
lstm_kernel_size=3,
optimizer=keras.optimizers.Adam(
learning_rate=1e-3,
),
)
return params


def fake_input_batch(
Expand Down Expand Up @@ -68,11 +62,13 @@ def fake_input_batch(
maxval=_LU_INDEX_VOCAB_SIZE,
dtype=tf.int32,
)
return {
"spatial": spatial,
"spatiotemporal": spatiotemporal,
"lu_index": lu_index,
}
return atmo_model.AtmoModel.Input(
spatial=spatial,
spatiotemporal=spatiotemporal,
lu_index=lu_index,
sim_name=tf.constant(["test"] * batch_size),
date=tf.constant(["test"] * batch_size),
)


def test_atmo_convlstm():
Expand All @@ -87,7 +83,7 @@ def test_atmo_convlstm():

expected_output_shape = (
batch_size,
params["output_timesteps"],
params.output_timesteps,
_TEST_MAP_HEIGHT,
_TEST_MAP_WIDTH,
constants.OUTPUT_CHANNELS, # T2, RH2, WSPD10, WDIR10_SIN, WDIR10_COS
Expand Down
2 changes: 1 addition & 1 deletion usl_models/tests/atmo_ml/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
F_S = constants.NUM_SAPTIAL_FEATURES
F_ST = constants.NUM_SPATIOTEMPORAL_FEATURES
C = constants.OUTPUT_CHANNELS
T_I, T_O = constants.INPUT_TIME_STEPS, 1
T_I, T_O = constants.INPUT_TIME_STEPS, 2


class TestAtmoMLDataset(usl_models.testing.TestCase):
Expand Down
51 changes: 0 additions & 51 deletions usl_models/usl_models/atmo_ml/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Constant definitions for the AtmoML model."""

import tensorflow as tf


# Geospatial constants
MAP_HEIGHT = 200
MAP_WIDTH = 200
Expand All @@ -16,51 +13,3 @@
TIME_STEPS_PER_DAY = 4
LU_INDEX_VOCAB_SIZE = 61
EMBEDDING_DIM = 8


def get_input_shape_batched(height, width):
    """Return the per-feature input shapes with a leading ``None`` entry.

    Looks up the input spec for the given map size via ``get_input_spec`` and,
    for every feature in it, builds a tuple made of ``None`` followed by that
    feature's spec shape (the ``None`` is presumably the batch dimension, as
    the function name suggests).

    Args:
        height: Map height forwarded to ``get_input_spec``.
        width: Map width forwarded to ``get_input_spec``.

    Returns:
        A dict keyed like the input spec, mapping each feature name to its
        shape tuple prefixed with ``None``.
    """
    batched_shapes = {}
    for feature_name, feature_spec in get_input_spec(height, width).items():
        batched_shapes[feature_name] = (None, *feature_spec.shape)
    return batched_shapes


def get_input_spec(height: int | None, width: int | None) -> dict[str, tf.TypeSpec]:
    """Build the unbatched tensor specs for each model input feature.

    Args:
        height: Map height used in the gridded feature shapes; may be ``None``.
        width: Map width used in the gridded feature shapes; may be ``None``.

    Returns:
        A dict with the following entries, in this order:
          * ``"spatiotemporal"``: float32,
            ``(INPUT_TIME_STEPS, height, width, NUM_SPATIOTEMPORAL_FEATURES)``.
          * ``"spatial"``: float32, ``(height, width, NUM_SAPTIAL_FEATURES)``.
          * ``"lu_index"``: int32, ``(height, width)``.
          * ``"sim_name"``: scalar string.
          * ``"date"``: scalar string.
    """
    # Shapes are spelled out first so each spec below reads as one line.
    spatiotemporal_shape = (
        INPUT_TIME_STEPS,
        height,
        width,
        NUM_SPATIOTEMPORAL_FEATURES,
    )
    spatial_shape = (height, width, NUM_SAPTIAL_FEATURES)
    grid_shape = (height, width)

    # Insertion order matches the original literal so dict iteration is unchanged.
    spec: dict[str, tf.TypeSpec] = {}
    spec["spatiotemporal"] = tf.TensorSpec(shape=spatiotemporal_shape, dtype=tf.float32)
    spec["spatial"] = tf.TensorSpec(shape=spatial_shape, dtype=tf.float32)
    spec["lu_index"] = tf.TensorSpec(shape=grid_shape, dtype=tf.int32)
    spec["sim_name"] = tf.TensorSpec(shape=(), dtype=tf.string)
    spec["date"] = tf.TensorSpec(shape=(), dtype=tf.string)
    return spec


def get_output_spec(height: int, width: int, timesteps: int) -> tf.TensorSpec:
    """Build the unbatched tensor spec for the model output (label) tensor.

    Args:
        height: Output map height.
        width: Output map width.
        timesteps: Number of output timesteps.

    Returns:
        A float32 ``tf.TensorSpec`` of shape
        ``(timesteps, height, width, OUTPUT_CHANNELS)``.
    """
    output_shape = (timesteps, height, width, OUTPUT_CHANNELS)
    return tf.TensorSpec(shape=output_shape, dtype=tf.float32)
35 changes: 16 additions & 19 deletions usl_models/usl_models/atmo_ml/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from google.cloud import storage # type: ignore
from google.cloud.storage import transfer_manager # type: ignore

from usl_models.atmo_ml import constants, vars
from usl_models.atmo_ml import constants, vars, model
from usl_models.shared import downloader


Expand All @@ -35,11 +35,11 @@

@dataclasses.dataclass(kw_only=True, frozen=True)
class Config:
input_width: int = 200
input_height: int = 200
output_width: int = 200
output_height: int = 200
output_timesteps: int = 1
input_width: int = constants.MAP_WIDTH
input_height: int = constants.MAP_HEIGHT
output_width: int = constants.MAP_WIDTH
output_height: int = constants.MAP_HEIGHT
output_timesteps: int = constants.OUTPUT_TIME_STEPS


def get_date(filename: str) -> str:
Expand Down Expand Up @@ -107,14 +107,11 @@ def get_cached_sim_dates(path: pathlib.Path) -> list[tuple[str, str]]:

def get_output_signature(
config: Config,
) -> tuple[dict[str, tf.TypeSpec], tf.TensorSpec]:
) -> tuple[model.AtmoModel.InputSpec, tf.TensorSpec]:
params = model.AtmoModel.Params(output_timesteps=config.output_timesteps)
return (
constants.get_input_spec(height=config.input_height, width=config.input_width),
constants.get_output_spec(
height=config.output_height,
width=config.output_width,
timesteps=config.output_timesteps,
),
model.AtmoModel.get_input_spec(params),
model.AtmoModel.get_output_spec(params),
)


Expand All @@ -132,7 +129,7 @@ def load_dataset_cached(
if shuffle:
random.shuffle(example_keys)

def generator() -> Iterable[tuple[dict[str, tf.Tensor], tf.Tensor]]:
def generator() -> Iterable[tuple[model.AtmoModel.Input, tf.Tensor]]:
missing_days: int = 0
generated_count: int = 0
if example_keys is None:
Expand Down Expand Up @@ -219,7 +216,7 @@ def load_dataset(
hash_range,
)

def generator() -> Iterable[tuple[dict[str, tf.Tensor], tf.Tensor]]:
def generator() -> Iterable[tuple[model.AtmoModel.Input, tf.Tensor]]:
missing_days: int = 0
generated_count: int = 0

Expand Down Expand Up @@ -260,7 +257,7 @@ def load_day(
feature_bucket: storage.Bucket,
label_bucket: storage.Bucket,
config: Config,
) -> tuple[dict[str, tf.Tensor], tf.Tensor] | None:
) -> tuple[model.AtmoModel.Input, tf.Tensor] | None:
"""Loads a single example from (sim_name, date)."""
logging.info("load_day('%s', '%s')" % (sim_name, date.strftime(DATE_FORMAT)))
start_filename = date.strftime(FEATURE_FILENAME_FORMAT)
Expand Down Expand Up @@ -291,7 +288,7 @@ def load_day(
return None

return (
dict(
model.AtmoModel.Input(
spatiotemporal=spatiotemporal_data,
spatial=spatial_data,
lu_index=lu_index_data,
Expand Down Expand Up @@ -353,7 +350,7 @@ def load_day_label(
@functools.lru_cache(maxsize=128)
def load_day_cached(
filecache_dir: pathlib.Path, sim_name: str, date: datetime, config: Config
) -> tuple[dict[str, tf.Tensor], tf.Tensor] | None:
) -> tuple[model.AtmoModel.Input, tf.Tensor] | None:
spatiotemporal = load_day_spatiotemporal_cached(
filecache_dir / sim_name / "spatiotemporal", date, config
)
Expand All @@ -377,7 +374,7 @@ def load_day_cached(
)

return (
dict(
model.AtmoModel.Input(
spatiotemporal=spatiotemporal,
spatial=tf.convert_to_tensor(spatial, dtype=tf.float32),
lu_index=tf.convert_to_tensor(lu_index, dtype=tf.int32),
Expand Down
Loading

0 comments on commit 10108ea

Please sign in to comment.