Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement loading dataset from GCP #196

Merged
merged 37 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
f90413f
adding testing files
SallyElHajjar Oct 29, 2024
d987ea3
adding new functions
SallyElHajjar Oct 31, 2024
373bf2c
adding new corrections
SallyElHajjar Oct 31, 2024
44cf2a4
adding new data
SallyElHajjar Nov 4, 2024
ba92ee1
Fix data loading from GCS and update test mocks
jainrajan98 Nov 5, 2024
ad65bf6
fixing flake8 errors and time steps
SallyElHajjar Nov 5, 2024
c992f2c
removing some subfunctions to module level and adding docs
SallyElHajjar Nov 6, 2024
964b1e7
testingwith true data
SallyElHajjar Nov 12, 2024
8d87a62
fixing black error
SallyElHajjar Nov 12, 2024
3b7aea0
improving the testing file
SallyElHajjar Nov 15, 2024
3c49619
Add Notebook
SallyElHajjar Nov 15, 2024
f02522a
Merge branch 'main' of https://github.com/UrbanSystemsLab/climateiq-c…
SallyElHajjar Nov 15, 2024
9d465a0
Merge branch 'dataset' of https://github.com/UrbanSystemsLab/climatei…
SallyElHajjar Nov 15, 2024
292e546
Update dataset notebook
Katsutoshii Nov 15, 2024
22ef360
Fix shape errors in dataset.py
Katsutoshii Nov 18, 2024
8a74aba
Fix shape issues for atmo ML training.
Katsutoshii Nov 20, 2024
8df4d23
correcting training
SallyElHajjar Nov 21, 2024
eb66377
remove mock url
SallyElHajjar Nov 22, 2024
2c74dc6
fix flake8
SallyElHajjar Nov 22, 2024
b4d70f9
Fix dataset test
Katsutoshii Nov 25, 2024
c1a91aa
Merge branch 'modelchange' of https://github.com/UrbanSystemsLab/clim…
Katsutoshii Nov 25, 2024
e15db9a
Remove print statments
Katsutoshii Nov 25, 2024
a1e1eb4
fixing flake8 error
SallyElHajjar Nov 27, 2024
d7ec502
fixing black error
SallyElHajjar Nov 27, 2024
3942a9f
fixing a typo in training notebook
SallyElHajjar Nov 27, 2024
cd8b9ef
testing atmo_utils
SallyElHajjar Nov 27, 2024
4423f18
Fixing shape errors in the model testing
SallyElHajjar Nov 27, 2024
6e4f4cc
Fix mypy errors on testing __init__.py.
Katsutoshii Nov 27, 2024
2a98fb8
fixing black error
SallyElHajjar Dec 2, 2024
8cbb012
Merge branch 'modelchange' of https://github.com/UrbanSystemsLab/clim…
SallyElHajjar Dec 2, 2024
46ab794
fixing black error
SallyElHajjar Dec 2, 2024
c98f0c2
Expose max_blobs in dataset API
Katsutoshii Dec 2, 2024
c95fd46
fixing flake8 error
SallyElHajjar Dec 2, 2024
483a59a
fixing flake8 error
SallyElHajjar Dec 2, 2024
aee4ef8
fixing black error
SallyElHajjar Dec 2, 2024
7129593
updating dataset_test
SallyElHajjar Dec 2, 2024
4ec6ce8
correcting downloader error
SallyElHajjar Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cloud_functions/test-local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
flake8 usl_pipeline/cloud_functions --show-source --statistics
black usl_pipeline/cloud_functions --check
pytest usl_pipeline/cloud_functions
mypy usl_pipeline/cloud_functions
mypy usl_pipeline/cloud_functions
6 changes: 6 additions & 0 deletions .github/workflows/usl_models/test-local.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# .github/workflows/usl_models/test-local.sh
# if it fails, we should give permission: chmod +x /home/elhajjas/climateiq-cnn/.github/workflows/usl_models/test-local.sh
flake8 usl_models --show-source --statistics
black usl_models --check
pytest usl_models -k "not integration"
mypy usl_models
6 changes: 3 additions & 3 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"jupyter.notebookFileRoot": "${fileDirname}/..",
"python.testing.pytestArgs": [
"usl_models", "--rootdir=usl_models", "-k", "not integration"
"usl_models"
],
"python.testing.cwd": "usl_models",
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"python.testing.pytestPath": "pytest"
}
345 changes: 345 additions & 0 deletions usl_models/notebooks/train_atmo_model.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions usl_models/tests/atmo_ml/atmo_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from usl_models.atmo_ml import constants
from usl_models.atmo_ml import model_params

_TEST_MAP_HEIGHT = 100
_TEST_MAP_WIDTH = 100
_TEST_SPATIAL_FEATURES = 17 # lu_index is now separate
_TEST_SPATIOTEMPORAL_FEATURES = 9
_TEST_MAP_HEIGHT = 200
_TEST_MAP_WIDTH = 200
_TEST_SPATIAL_FEATURES = 22 # lu_index is now separate
_TEST_SPATIOTEMPORAL_FEATURES = 12
_LU_INDEX_VOCAB_SIZE = 61
_EMBEDDING_DIM = 8

Expand Down
8 changes: 0 additions & 8 deletions usl_models/tests/atmo_ml/atmo_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def test_split_time_step_pairs():
expected_output = tf.constant(
[
[
[
[[0, 1], [4, 5]],
[[8, 9], [12, 13]],
],
[
[[2, 3], [6, 7]],
[[10, 11], [14, 15]],
Expand All @@ -82,10 +78,6 @@ def test_split_time_step_pairs():
[[32, 33], [36, 37]],
[[40, 41], [44, 45]],
],
[
[[34, 35], [38, 39]],
[[42, 43], [46, 47]],
],
],
]
)
Expand Down
105 changes: 105 additions & 0 deletions usl_models/tests/atmo_ml/dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import io

from unittest import mock
from unittest.mock import MagicMock

import numpy as np

import usl_models.testing
from usl_models.atmo_ml import dataset
from usl_models.atmo_ml import constants


def create_mock_blob(data, dtype=np.float32, allow_pickle=True):
"""Create a mock blob with simulated data and return it."""
blob = MagicMock()
buf = io.BytesIO()
np.save(buf, data.astype(dtype), allow_pickle=allow_pickle)
buf.seek(0)
blob.open.return_value = buf
return blob


class TestAtmoMLDataset(usl_models.testing.TestCase):
@mock.patch("google.cloud.storage.Client")
def test_load_dataset_structure(self, mock_storage_client):
"""Test creating AtmoML dataset from GCS with expected structure and shapes."""
# Mock GCS client and bucket
mock_storage_client_instance = mock_storage_client.return_value
mock_bucket = MagicMock()
mock_storage_client_instance.bucket.return_value = mock_bucket

num_days = 4
timesteps_per_day = 6
num_timesteps = num_days * timesteps_per_day
batch_size = 2

B = batch_size
H, W = constants.MAP_HEIGHT, constants.MAP_WIDTH
F_S = constants.NUM_SAPTIAL_FEATURES
F_ST = constants.NUM_SPATIOTEMPORAL_FEATURES
C = constants.OUTPUT_CHANNELS
T_I, T_O = constants.INPUT_TIME_STEPS, constants.OUTPUT_TIME_STEPS

# Simulate mock blobs for datasets
mock_spatial_blob = create_mock_blob(
np.random.rand(H, W, F_S).astype(np.float32)
)
mock_spatiotemporal_tensor = np.random.rand(H, W, F_ST).astype(np.float32)
mock_spatiotemporal_blobs = [
create_mock_blob(mock_spatiotemporal_tensor) for _ in range(num_timesteps)
]
mock_lu_index_blob = create_mock_blob(
np.random.randint(
low=0,
high=10,
size=(H, W),
).astype(np.int32)
)
mock_label_blobs = [
create_mock_blob(np.random.rand(H, W, C).astype(np.float32))
for _ in range(num_timesteps)
]

# Mock blob listing behavior to simulate folder structure
mock_bucket.list_blobs.side_effect = lambda prefix: {
"sim1/spatial": [mock_spatial_blob],
"sim1/spatiotemporal": mock_spatiotemporal_blobs,
"sim1/lu_index": [mock_lu_index_blob],
"sim1": mock_label_blobs,
}[prefix]

# Define bucket names and folder paths
data_bucket_name = "test-data-bucket"
label_bucket_name = "test-label-bucket"

# Call the function under test
ds = dataset.load_dataset(
data_bucket_name=data_bucket_name,
label_bucket_name=label_bucket_name,
sim_names=["sim1"],
timesteps_per_day=timesteps_per_day,
storage_client=mock_storage_client_instance,
)
ds = ds.batch(batch_size=batch_size)

inputs, labels = zip(*ds)
num_batches = num_days // batch_size
self.assertShapesRecursive(
list(inputs),
[
{
"spatiotemporal": (B, T_I, H, W, F_ST),
"spatial": (B, H, W, F_S),
"lu_index": (B, H, W),
}
]
* num_batches,
)
self.assertShapesRecursive(
list(labels),
[
(B, T_O, H, W, C),
]
* num_batches,
)
194 changes: 0 additions & 194 deletions usl_models/tests/atmo_ml/datasets_test.py

This file was deleted.

Loading