Skip to content

Commit

Permalink
removed legacy np.random calls
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Oct 4, 2024
1 parent 54fdbd7 commit 346da44
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 64 deletions.
16 changes: 0 additions & 16 deletions mlclouds/_version.py

This file was deleted.

86 changes: 44 additions & 42 deletions mlclouds/autoxval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from mlclouds.utilities import ALL_SKY_VARS, CONFIG, FP_DATA, surf_meta
from mlclouds.validator import Validator

# Module-level NumPy random generator backed by a PCG64 bit generator with a
# fixed seed of 42. The closing parenthesis of the outer call was missing,
# which made the whole module fail to parse (Ruff E999).
RANDOM_GENERATOR = np.random.Generator(np.random.PCG64(42))

logger = logging.getLogger(__name__)

Check failure on line 25 in mlclouds/autoxval.py

View workflow job for this annotation

GitHub Actions / Ruff

Ruff (E999)

mlclouds/autoxval.py:25:1: E999 SyntaxError: Expected ',', found name

Check failure on line 26 in mlclouds/autoxval.py

View workflow job for this annotation

GitHub Actions / Ruff

Ruff (E999)

mlclouds/autoxval.py:25:37: E999 SyntaxError: Expected ')', found newline

Expand Down Expand Up @@ -143,8 +145,8 @@ def train(self, train_sites=[0, 1, 2, 3, 5, 6], train_files=FP_DATA):

self._model = trainer.model
self._train_data = trainer.train_data
self._config["train_files"] = self._train_data.train_files
self._config["train_sites"] = self._train_data.train_sites
self._config['train_files'] = self._train_data.train_files
self._config['train_sites'] = self._train_data.train_sites

def validate(
self,
Expand Down Expand Up @@ -175,7 +177,7 @@ def validate(
Save time series data to disk
"""
if self._model is None or self._config is None:
msg = "A model must be trained or loaded before validating."
msg = 'A model must be trained or loaded before validating.'
logger.critical(msg)
raise RuntimeError(msg)

Expand All @@ -201,7 +203,7 @@ def load_model(self, fname):
File name and path of pickle file
"""
self._model = MLCloudsModel.load(fname)
with open(fname + ".config", "rb") as f:
with open(fname + '.config', 'rb') as f:
self._config = json.load(f)

def save_model(self, fname):
Expand All @@ -214,7 +216,7 @@ def save_model(self, fname):
File name and path for pickle file
"""
self._model.save_model(fname)
with open(fname + ".config", "w") as f:
with open(fname + '.config', 'w') as f:
json.dump(self._config, f)

def save_stats(self, fname):
Expand All @@ -227,16 +229,16 @@ def save_stats(self, fname):
File name and path for stats CSV file
"""
if self.stats is None:
msg = "Statistics do not exist. Run XVal.validate() first"
msg = 'Statistics do not exist. Run XVal.validate() first'
logger.critical(msg)
raise RuntimeError(msg)

self.stats.to_csv(fname, index=False)
conf = deepcopy(self._config)

conf["val_files"] = self._validator.val_data.val_files
conf["val_files_meta"] = self._validator.val_data.files_meta
with open(fname[:-4] + ".json", "w") as f:
conf['val_files'] = self._validator.val_data.val_files
conf['val_files_meta'] = self._validator.val_data.files_meta
with open(fname[:-4] + '.json', 'w') as f:
json.dump(conf, f, indent=4)

def plot(self, gid):
Expand All @@ -248,15 +250,15 @@ def plot(self, gid):
gid: int
gid code of desired surfrad site to plot statistics for.
"""
code = surf_meta().loc[gid, "surfrad_id"]
for ylabel in ["MBE (%)", "MAE (%)", "RMSE (%)"]:
code = surf_meta().loc[gid, 'surfrad_id']
for ylabel in ['MBE (%)', 'MAE (%)', 'RMSE (%)']:
fig = px.bar(
self.stats[(self.stats.Site == code.upper())],
x="Condition",
x='Condition',
y=ylabel,
color="Model",
facet_col="Variable",
barmode="group",
color='Model',
facet_col='Variable',
barmode='group',
height=400,
)
fig.show()
Expand Down Expand Up @@ -309,8 +311,8 @@ def __init__(
save_timeseries: bool
Save time series data to disk
"""
if seed is not None:
np.random.seed(seed)
seed = 42 if seed is None else seed
rng = np.random.Generator(np.random.PCG64(seed))

if val_sites is None:
val_sites = sites
Expand All @@ -325,29 +327,29 @@ def __init__(
self._data_files = data_files

logger.info(
"AXV: training sites are {}, val sites are {}" "".format(
'AXV: training sites are {}, val sites are {}' ''.format(
sites, val_sites
)
)

self.temp_stats = None

if val_data is None:
logger.info("Loading validation data from {}".format(data_files))
logger.info('Loading validation data from {}'.format(data_files))
val_data = ValidationData(
val_files=data_files,
val_sites="all",
features=config["features"],
y_labels=config["y_labels"],
val_sites='all',
features=config['features'],
y_labels=config['y_labels'],
all_sky_vars=ALL_SKY_VARS,
one_hot_cats=config["one_hot_categories"],
one_hot_cats=config['one_hot_categories'],
)
self._val_data = val_data

for val_site in val_sites:
train_set = [x for x in sites if x != val_site]
if shuffle_train:
np.random.shuffle(train_set)
rng.shuffle(train_set)
self._run_train_set(
val_site,
train_set,
Expand Down Expand Up @@ -386,7 +388,7 @@ def _run_train_set(
Save time series data to disk
"""
logger.info(
"Training set is {}, val site is {}".format(train_set, val_site)
'Training set is {}, val site is {}'.format(train_set, val_site)
)
for i in range(min_train - 1, len(train_set)):
train_sites = train_set[0 : i + 1]
Expand All @@ -396,7 +398,7 @@ def _run_train_set(
except ArithmeticError as e:
if catch_nan:
logger.warning(
"Loss=nan, val on {}, train on {}" "".format(
'Loss=nan, val on {}, train on {}' ''.format(
val_site, train_sites
)
)
Expand All @@ -407,10 +409,10 @@ def _run_train_set(
)

# The _ prevents the 0 from being trimmed off
ts = "_" + "".join([str(x) for x in train_sites])
xv.stats["val_site"] = val_site
xv.stats["train_sites"] = ts
xv.stats["num_ts"] = len(ts) - 1
ts = '_' + ''.join([str(x) for x in train_sites])
xv.stats['val_site'] = val_site
xv.stats['train_sites'] = ts
xv.stats['num_ts'] = len(ts) - 1

if self.temp_stats is None:
self.temp_stats = xv.stats
Expand All @@ -419,7 +421,7 @@ def _run_train_set(

self.stats = self.temp_stats.reset_index()

def save_stats(self, path="./stats", fname=None):
def save_stats(self, path='./stats', fname=None):
"""
Save validation statistics csv and config as json
Expand All @@ -435,22 +437,22 @@ def save_stats(self, path="./stats", fname=None):
os.makedirs(path)

if fname is None:
sites_name = "".join([str(x) for x in self._sites])
val_name = "".join([str(x) for x in self._val_sites])
fname = "axv_stats_{}_{}".format(sites_name, val_name)
sites_name = ''.join([str(x) for x in self._sites])
val_name = ''.join([str(x) for x in self._val_sites])
fname = 'axv_stats_{}_{}'.format(sites_name, val_name)

fpath = os.path.join(path, fname)
self.stats.to_csv(fpath + ".csv")
logger.info("Saved stats to: {}".format(fpath + ".csv"))
self.stats.to_csv(fpath + '.csv')
logger.info('Saved stats to: {}'.format(fpath + '.csv'))

conf = deepcopy(self._config)
conf["data_files"] = self._data_files
conf["val_files"] = self._val_data.val_files
conf["sites"] = self._sites
conf["val_sites"] = self._val_sites
with open(fpath + ".json", "w") as f:
conf['data_files'] = self._data_files
conf['val_files'] = self._val_data.val_files
conf['sites'] = self._sites
conf['val_sites'] = self._val_sites
with open(fpath + '.json', 'w') as f:
json.dump(conf, f, indent=4)
logger.info("Saved config to: {}".format(fpath + ".json"))
logger.info('Saved config to: {}'.format(fpath + '.json'))

@classmethod
def k_fold(
Expand Down
2 changes: 1 addition & 1 deletion mlclouds/tdisc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
4) Water and Pressure were changed to vectors from scalars
"""

import tensorflow as tf
import numpy as np
import tensorflow as tf
from farms import SOLAR_CONSTANT, SZA_LIM


Expand Down
11 changes: 6 additions & 5 deletions tests/test_xval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
STAT_FIELDS = ["MAE (%)", "MBE (%)", "RMSE (%)"]


def check_for_eagle():
if not os.path.exists("/lustre/eaglefs/projects/pxs/mlclouds/"):
def check_for_hpc():
"""Check for projects directory. Skip this test if not found."""
if not os.path.exists("/projects/pxs/mlclouds/"):
msg = (
"These tests require access to /projects/pxs/mlclouds/ and "
"can only be run on the Eagle HPC"
"can only be run on the HPC"
)
pytest.skip(msg)

Expand All @@ -32,7 +33,7 @@ def test_xval():
Test that xval creates the proper results for a simple model. Also test
model saving and loading.
"""
check_for_eagle()
check_for_hpc()

config = CONFIG
config["epochs_a"] = 4
Expand Down Expand Up @@ -73,7 +74,7 @@ def test_xval_mismatched_timesteps():
Test training and validation with 5 minute and 30 minute GOES data at
the same time.
"""
check_for_eagle()
check_for_hpc()

config = CONFIG
config["epochs_a"] = 4
Expand Down

0 comments on commit 346da44

Please sign in to comment.