diff --git a/usl_models/tests/atmo_ml/dataset_test.py b/usl_models/tests/atmo_ml/dataset_test.py index cf830c71..ce6e6263 100644 --- a/usl_models/tests/atmo_ml/dataset_test.py +++ b/usl_models/tests/atmo_ml/dataset_test.py @@ -92,13 +92,15 @@ def test_load_dataset_structure(self, mock_storage_client): "spatial": (B, H, W, F_S), "lu_index": (B, H, W), } - ] * num_batches, + ] + * num_batches, ) self.assertShapesRecursive( list(labels), [ (B, T_O, H, W, C), - ] * num_batches, + ] + * num_batches, ) @mock.patch("google.cloud.storage.Client") diff --git a/usl_models/usl_models/testing/__init__.py b/usl_models/usl_models/testing/__init__.py index 7be765ad..c08a570c 100644 --- a/usl_models/usl_models/testing/__init__.py +++ b/usl_models/usl_models/testing/__init__.py @@ -7,6 +7,7 @@ class TestCase(unittest.TestCase): """Testing utils.""" + def assertShapesRecursive(self, obj: object, expected: object, path: str = ""): """Recursively checks the shapes of numpy arrays in a data structure.""" if isinstance(obj, np.ndarray) or tf.is_tensor(obj):