Skip to content

Commit

Permalink
Merge pull request #215 from UrbanSystemsLab/htune
Browse files Browse the repository at this point in the history
Support hyperparameter tuning.
  • Loading branch information
Katsutoshii authored Mar 3, 2025
2 parents 10108ea + baec189 commit 41d48b5
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 22 deletions.
147 changes: 147 additions & 0 deletions usl_models/notebooks/htune_atmo_model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AtmoML Hyperparameter Tuning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import keras_tuner\n",
"import keras\n",
"import time\n",
"import pathlib\n",
"\n",
"from usl_models.atmo_ml.model import AtmoModel\n",
"from usl_models.atmo_ml import dataset, visualizer, vars\n",
"\n",
"\n",
"logging.getLogger().setLevel(logging.WARNING)\n",
"keras.utils.set_random_seed(812)\n",
"visualizer.init_plt()\n",
"\n",
"batch_size = 8\n",
"filecache_dir = pathlib.Path(\"/home/shared/climateiq/filecache\")\n",
"example_keys = [\n",
" (\"NYC_Heat_Test/NYC_summer_2000_01p\", \"2000-05-25\"),\n",
" (\"NYC_Heat_Test/NYC_summer_2000_01p\", \"2000-05-26\"),\n",
" (\"NYC_Heat_Test/NYC_summer_2000_01p\", \"2000-05-27\"),\n",
" (\"NYC_Heat_Test/NYC_summer_2000_01p\", \"2000-05-28\"),\n",
" (\"PHX_Heat_Test/PHX_summer_2008_25p\", \"2008-05-25\"),\n",
" (\"PHX_Heat_Test/PHX_summer_2008_25p\", \"2008-05-26\"),\n",
" (\"PHX_Heat_Test/PHX_summer_2008_25p\", \"2008-05-27\"),\n",
" (\"PHX_Heat_Test/PHX_summer_2008_25p\", \"2008-05-28\"),\n",
"]\n",
"timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
"\n",
"ds_config = dataset.Config(output_timesteps=2)\n",
"train_ds = dataset.load_dataset_cached(\n",
" filecache_dir,\n",
" example_keys=example_keys,\n",
" config=ds_config,\n",
").batch(batch_size=batch_size)\n",
"val_ds = dataset.load_dataset_cached(\n",
" filecache_dir,\n",
" example_keys=example_keys,\n",
" config=ds_config,\n",
" shuffle=False,\n",
").batch(batch_size=batch_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tuner = keras_tuner.BayesianOptimization(\n",
" AtmoModel.get_hypermodel(\n",
" input_cnn_kernel_size=[1, 2, 5],\n",
" lstm_kernel_size=[5],\n",
" ),\n",
" objective=\"val_loss\",\n",
" max_trials=10,\n",
" project_name=f\"logs/htune_project_{timestamp}\",\n",
")\n",
"tuner.search_space_summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"log_dir = f\"logs/htune_{timestamp}\"\n",
"print(log_dir)\n",
"tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"tuner.search(train_ds, epochs=100, validation_data=val_ds, callbacks=[tb_callback])\n",
"best_model, best_hp = tuner.get_best_models()[0], tuner.get_best_hyperparameters()[0]\n",
"best_hp.values"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train the best option further and save.\n",
"model = AtmoModel(model=best_model)\n",
"tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"model.fit(train_ds, val_ds, epochs=200, callbacks=[tb_callback], validation_freq=10)\n",
"model.save_model(log_dir + \"/model\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Plot results\n",
"model = AtmoModel.from_checkpoint(log_dir + \"/model\")\n",
"input_batch, label_batch = next(iter(val_ds))\n",
"pred_batch = model.call(input_batch)\n",
"\n",
"for fig in visualizer.plot_batch(\n",
" input_batch=input_batch,\n",
" label_batch=label_batch,\n",
" pred_batch=pred_batch,\n",
" st_var=vars.Spatiotemporal.RH,\n",
" sto_var=vars.SpatiotemporalOutput.RH2,\n",
" max_examples=None,\n",
"):\n",
" fig.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
5 changes: 3 additions & 2 deletions usl_models/notebooks/train_atmo_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
" (\"PHX_Heat_Test/PHX_summer_2008_25p\", \"2008-05-27\"),\n",
" (\"PHX_Heat_Test/PHX_summer_2008_25p\", \"2008-05-28\"),\n",
"]\n",
"timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
"\n",
"ds_config = dataset.Config(\n",
" output_timesteps=2)\n",
Expand Down Expand Up @@ -84,10 +85,10 @@
"source": [
"# Train the model\n",
"# Create a unique log directory by appending the current timestamp\n",
"log_dir = os.path.join(\"./logs\", \"run_k5_\" + time.strftime(\"%Y%m%d-%H%M%S\"))\n",
"log_dir = os.path.join(\"./logs\", \"run_\" + timestamp)\n",
"print(log_dir)\n",
"tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"model.fit(train_ds, val_ds, epochs=10, callbacks=[tb_callback], validation_freq=10)\n",
"model.fit(train_ds, val_ds, epochs=100, callbacks=[tb_callback], validation_freq=10)\n",
"model.save_model(log_dir + \"/model\")"
]
},
Expand Down
1 change: 1 addition & 0 deletions usl_models/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ matplotlib==3.9.2
numpy==1.26.4
tensorflow[and-cuda]==2.15.1
keras==2.15.0
keras-tuner[bayesian]==1.4.7
google-cloud-aiplatform==1.43.0
google-cloud-storage==2.15.0
google-cloud-firestore==2.15.0
Expand Down
1 change: 1 addition & 0 deletions usl_models/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# firestore is not present in the image, but we match the cloud-storage version.
"google-cloud-firestore==2.15.0",
"seaborn==0.13.2",
"keras-tuner[bayesian]==1.4.7",
],
extras_require={
"dev": [
Expand Down
49 changes: 29 additions & 20 deletions usl_models/usl_models/atmo_ml/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""AtmoML model definition."""

import logging
import dataclasses
from typing import TypedDict, List, Callable, Literal

import keras
from keras import layers
import keras_tuner
import tensorflow as tf

from usl_models.atmo_ml import data_utils
Expand All @@ -29,15 +31,18 @@ class AtmoModel:
class Params(keras_dataclasses.Base):
"""Model parameters."""

# Layer-specific parameters.
# Input CNN params
input_cnn_kernel_size: int = 5

# LSTM parameters.
lstm_units: int = 64
lstm_kernel_size: int = 5
lstm_dropout: float = 0.2
lstm_recurrent_dropout: float = 0.2

# The optimizer configuration.
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam(
learning_rate=1e-3
optimizer: keras.optimizers.Optimizer = dataclasses.field(
default_factory=lambda: keras.optimizers.Adam(learning_rate=1e-3)
)

output_timesteps: int = constants.OUTPUT_TIME_STEPS
Expand Down Expand Up @@ -110,25 +115,19 @@ def get_output_spec(cls, params: Params) -> tf.TensorSpec:
dtype=tf.float32,
)

def __init__(
self,
params: Params | None = None,
):
def __init__(self, params: Params | None = None, model: keras.Model | None = None):
"""Creates the Atmo model.
Args:
params: A dictionary of configurable model parameters.
spatial_dims: Tuple of spatial height and width input dimensions.
Needed for defining input shapes. This is an optional arg that
can be changed (primarily for testing/debugging).
num_spatial_features: nb of spt features
num_spatiotemporal_features: nb of spatiotemp feat.
lu_index_vocab_size: Number of unique values in the lu_index
feature.
embedding_dim: Size of the embedding vectors for lu_index.
params: Model parameters
model: If you already have a keras.Model constructed, pass it here.
"""
self._params = params or self.Params()
self._model = self._build_model()
if model is not None:
self._params = model._params # type: ignore
self._model = model
else:
self._params = params or self.Params()
self._model = self._build_model()

@classmethod
def from_checkpoint(cls, artifact_uri: str, **kwargs) -> "AtmoModel":
Expand All @@ -151,6 +150,16 @@ def from_checkpoint(cls, artifact_uri: str, **kwargs) -> "AtmoModel":
model._model.set_weights(loaded_model.get_weights())
return model

@classmethod
def get_hypermodel(
    cls, **kwargs
) -> Callable[[keras_tuner.HyperParameters], keras.Model]:
    """Returns a model-building function with the given param overrides.

    The returned object is a plain function (not a `keras_tuner.HyperModel`
    instance); it is meant to be passed as the hypermodel argument of a
    KerasTuner tuner.

    Args:
        **kwargs: Maps a `Params` field name to the candidate values to
            search over. Each entry is registered with `hp.Choice` under
            the field's own name.

    Returns:
        A function that takes a `keras_tuner.HyperParameters`, picks one
        value per overridden field, builds `Params` from those picks, and
        returns the `_model` of the resulting instance.
    """

    def hypermodel(hp: keras_tuner.HyperParameters) -> keras.Model:
        # One Choice per overridden field; fields not named keep their
        # `Params` defaults.
        hp_kwargs = {
            name: hp.Choice(name, values) for name, values in kwargs.items()
        }
        return cls(cls.Params(**hp_kwargs))._model

    return hypermodel

def _build_model(self) -> keras.Model:
"""Creates the correct internal (Keras) model architecture."""
model = AtmoConvLSTM(self._params)
Expand Down Expand Up @@ -295,7 +304,7 @@ def __init__(self, params: AtmoModel.Params):

# Model definition
T, H, W = None, None, None
K_SIZE = self._params.lstm_kernel_size # Conv kernel size
K_SIZE = self._params.input_cnn_kernel_size
C1_STRIDE, C2_STRIDE = (
self._params.conv1_stride,
self._params.conv2_stride,
Expand Down Expand Up @@ -372,7 +381,7 @@ def __init__(self, params: AtmoModel.Params):
layers.InputLayer((T, LSTM_H, LSTM_W, LSTM_C)),
layers.ConvLSTM2D(
LSTM_FILTERS,
K_SIZE,
self._params.lstm_kernel_size,
return_sequences=True,
strides=1,
padding="same",
Expand Down

0 comments on commit 41d48b5

Please sign in to comment.