diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1d72cb2e26..3ed93d9ade 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -15,7 +15,7 @@ # Top level markdown files: # Keep CHANGELOG separate - don't require specific people to approve updates to it. -./CHANGELOG.md +./CHANGELOG.md ./*.md @ram-cherukuri @megnvidia .github/workflows @ktangsali @coreyjadams @nickgeneva @@ -57,17 +57,21 @@ physicsnemo/deploy physicsnemo/mesh @peterdsharpe # DATAPIPES +physicsnemo/datapipes/core @coreyjadams physicsnemo/datapipes/climate @dallasfoster @nickgeneva @pzharrington physicsnemo/datapipes/cae/domino_datapipe.py @RishikeshRanade @coreyjadams physicsnemo/datapipes/cae/mesh_datapipe.py @mnabian @Alexey-Kamenev physicsnemo/datapipes/cae/readers.py @mnabian physicsnemo/datapipes/gnn @mnabian @Alexey-Kamenev -physicsnemo/datapipes/healpix @pzharrington +physicsnemo/datapipes/healpix @pzharrington +examples/minimal/datapipes @coreyjadams +test/datapipes/core @coreyjadams + # Distributed tools physicsnemo/distributed @coreyjadams # MODEL SPECIFIC OWNERSHIP. These are in roughly the same order as the code repo. -# They are grouped "logically" together, though. +# They are grouped "logically" together, though. # For example, gnn_layers, graphcast, and meshgraphnet are together. # AFNO diff --git a/examples/minimal/datapipes/.gitignore b/examples/minimal/datapipes/.gitignore new file mode 100644 index 0000000000..1ecf5e7224 --- /dev/null +++ b/examples/minimal/datapipes/.gitignore @@ -0,0 +1,4 @@ +*.json +*.npz +*.npy +*.zarr diff --git a/examples/minimal/datapipes/README.md b/examples/minimal/datapipes/README.md new file mode 100644 index 0000000000..5f88cebf4f --- /dev/null +++ b/examples/minimal/datapipes/README.md @@ -0,0 +1,209 @@ +# PhysicsNeMo DataPipes + +Dataloading is critical to SciML applications, both for training and inference, +and the physicsnemo datapipe infrastructure aims to deliver a flexible and configurable +set of tools to enable your application. + +There are plenty of tools in the python eco system for loading, preprocessing, and +preparing your data for training or inference. To compare / contrast some of these +tools with the ecosystem available, and see if the physicsnemo datapipe interface +might be valuable to your workload, consider the following design principles +we followed when building the physicsnemo datapipes: + +1. **GPU-first** - Many scientific datasets are *large* for even a single example: +the data is high resolution and the preprocessing needs benefit from GPU acceleration. +Compare this to other methods where the data preprocessing is predominantly CPU-based, +such as the pytorch Dataloader: whereas CPU-based preprocessing may introduce GPU +pipeline stalls on high resolution data, GPU-based preprocessing will maximize +throughput. + +2. **Threading over Multiprocessing** - In python, true concurrency is typically only +available via multiprocessing or when offloading to compiled libraries or GPU kernels. +For this reason, many data loaders leverage multiprocessing for data concurrency: +load images in separate processes, and collate a batch on the main thread. +For simplicity, with a GPU-first paradigm, the physicsnemo datapipe focuses on GPU +concurrency via asynchronous execution and stream-based parallelism. IO is coordinated +in multiple threads, instead of multiple processes, and streams enable multiple +preprocessing pipelines to execute concurrently on the GPU. + +3. 
**Unambiguous Configuration and Serialization** - Datapipes can be a particularly +frustrating component in reproducibility of AI results - the preprocessing, sampling, +batching and other parameters can be hard to infer from training scripts. Here, +we make a deliberate design choice to enable datapipe configuration serialization +as a first-class citizen. PhysicsNeMo Datapipes can be built directly in Python, +but also instantiated from hydra yaml files for version control and distribution. + +4. **Familiar Interfaces** - We built our tools from scratch, but they are meant +to look familiar and inter-operate with the tools you already know. Use +physicsnemo DataLoaders as a replacement for PyTorch's Dataloader; tools like +DistributedSampler will still work. Users of `torchvision` will be familiar +with the concept of chaining transformations together. + +5. **Extensibility out of the box** - We want to provide a data pipeline that gives +great performance and usability immediately - but it will never be the case that +one codebase covers all possible data needs out of the box. Therefore, the +physicsnemo datapipe is extensible: you can build custom data readers for +new dataformats, and plug them in to datasets; you can build new transforms +for your data that we might not have, and simply plug them into a transformation +pipeline. You can even package all of this up as a pip-installable extension: Using +the built in registry enables you to still instantiate and version control datapipes, +when the components are not even part of PhysicsNeMo. + +## When should I use PhysicsNeMo datapipes over X/Y/Z data utility? + +In general, the physicsnemo datapipe utility is built to deliver good performance +on data that is large, per example, like most scientific data is. If you want a +batch size of 512 small images, it may be more performant to use a CPU-centric +tool. + +Another advantage of the PhysicsNeMo datapipe is the ability to build datapipes +directly from configuration files, allowing serializable and version-controlled +data configuration. This isn't the only tool that can do this, of course. + +## Core Datapipe Design + +Think of datasets as a hierarchy of data: at the highest level, an entire **dataset** +consists of independent **examples**. Each example has one or more **tensor components**: +image data may have input images and target labels; CFD data may have positions, +target pressures, a mesh object, boundary conditions, etc.; weather data may contain +sensor readings as a function of time. Each example may be the same size as the others, +or each example may be a unique size. Even the components of an example can be variable, +though this can require extra care in reading and using the dataset. + +The PhysicsNeMo datapipe consists of the following components: + +- `reader` objects contain the logic to understand a **dataset** on disk, and + load examples into CPU memory. + +- The `dataset` object, which contains a `reader`, orchestrates threads that preload + data **examples** from disk and move it to GPU. On the GPU, a `dataset` can apply a + series of transformations to each **example**. Each example is stored in `tensordict` + format. The dataset will also track metadata, for understanding where each **example** + came from (index, filepath, etc.). + +- A `transform` is a callable class that accepts a tensordict as input, and returns + a `tensordict` as output. Chaining transformations together is the core way to + manipulate data examples on the fly in a datapipe. 
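+
+  As a minimal sketch (the `ScaleVelocity` name and the `velocity` field are
+  illustrative, not part of PhysicsNeMo), a custom transform is just a
+  callable over a `TensorDict`:
+
+  ```python
+  from tensordict import TensorDict
+
+  class ScaleVelocity:
+      """Multiply the `velocity` component of each example by a constant."""
+
+      def __init__(self, factor: float):
+          self.factor = factor
+
+      def __call__(self, data: TensorDict) -> TensorDict:
+          # Modify the field and hand the same TensorDict back to the pipeline.
+          data["velocity"] = data["velocity"] * self.factor
+          return data
+  ```
+
+  Anything with this signature can be chained with the built-in transforms
+  (for example via `Compose`) or plugged into a `Dataset`'s transform list.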
+ +- The `dataloader` is a drop-in replacement for the PyTorch DataLoader, with additional + optimizations for the GPU-centric processing here. The `dataloader` handles + stream concurrency, batch collation, and triggering preloading of datasets. + +--- + +## Tutorials + +This directory contains progressive tutorials that teach you how to use the +PhysicsNeMo datapipe infrastructure effectively. Note that some of the tutorials +are repetitive and verbose, to highlight different features of the datapipe +ecosystem. We'll give some overview of what you can learn in each tutorial, +but they are meant to be run interactively and explored. + +### Data Prerequisites + +You do not need to have any specific data in hand for the tutorials. You can +generate synthetic data with the scripts `generate_regular_data.py` and +`generate_variable_points_data.py`. + +### Tutorial 1: Getting Started with DataPipes + +**File:** `tutorial_01_getting_started.py` + +Learn the core concepts of data loading from disk: + +- Creating a Reader to load data from files +- Understanding the `(TensorDict, metadata)` return format +- Wrapping a reader in a Dataset +- Iterating with a DataLoader +- Accessing batch data via TensorDict keys + +```bash +# Generate tutorial data first +python generate_regular_data.py -n 100 \ +-s "velocity:128,128,128,3 pressure:128,128,128,1 position:128,128,128,3" \ +-b zarr -o output/tutorial_data/ + +# Run the tutorial +python tutorial_01_getting_started.py +``` + +### Tutorial 2: Transforms and Data Preprocessing + +**File:** `tutorial_02_transforms.py` + +Build preprocessing pipelines with transforms: + +- Apply a single transform (Normalize) +- Compose multiple transforms together +- Subsample point clouds with SubsamplePoints +- Use geometric transforms (Translate, ReScale) +- Save/load normalization statistics from files +- Denormalize data with the `inverse()` method + +```bash +# Generate regular grid data (for most sections) +# Note: Tutorial 2 can reuse the data from Tutorial 1 +python generate_regular_data.py -n 100 \ +-s "velocity:128,128,128,3 pressure:128,128,128,1 position:128,128,128,3" \ +-b zarr -o output/tutorial_data/ + +# Generate point cloud data (for subsampling sections) +python generate_variable_points_data.py -n 100 \ +-s "coords:3 features:8" --min-points 50000 \ +--max-points 100000 -b zarr -o output/pointcloud_data/ + +# Run the tutorial +python tutorial_02_transforms.py +``` + +### Tutorial 3: Custom Collation for GNNs + +**File:** `tutorial_03_custom_gnn_datapipe.py` + +Build a GNN-ready data pipeline with custom collation: + +- Build a custom Transform for computing KNN graph edges +- Implement a custom Collator for PyG-style graph batching +- Understand how PyG batches graphs (offset edges, concatenate features, batch tensor) +- Put it all together in a complete GNN training pipeline + +```bash +# Generate point cloud data with coordinates and features (can be reused from tutorial 2) +python generate_variable_points_data.py -n 100 \ +-s "coords:3 features:8" --min-points 50000 \ +--max-points 100000 -b zarr -o output/pointcloud_data/ + +# Run the tutorial +python tutorial_03_custom_gnn_datapipe.py +``` + +### Tutorial 4: Hydra Configuration for DataPipes + +**File:** `tutorial_04_hydra_config.py` + +Build entire datapipes from YAML configuration with minimal Python code: + +- Define reader, transforms, dataset, and dataloader in YAML +- Use `hydra.utils.instantiate()` to build components +- Override any parameter from the command line +- Switch between 
configurations easily + +```bash +# Generate tutorial data (from tutorials 2 and 3) +python generate_variable_points_data.py -n 100 -s \ +"coords:3 features:8" --min-points 50000 \ +--max-points 100000 -b zarr -o output/pointcloud_data/ + +# Run with default config +python tutorial_04_hydra_config.py + +# Override from command line +python tutorial_04_hydra_config.py dataloader.batch_size=8 dataloader.dataset.device=cuda + +# Use point cloud configuration (this is the default) +python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud + +# Override transform parameters +python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ + subsample.n_points=5000 +``` diff --git a/examples/minimal/datapipes/conf/config.yaml b/examples/minimal/datapipes/conf/config.yaml new file mode 100644 index 0000000000..48b25abba2 --- /dev/null +++ b/examples/minimal/datapipes/conf/config.yaml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - reader/zarr \ No newline at end of file diff --git a/examples/minimal/datapipes/conf/reader/npz.yaml b/examples/minimal/datapipes/conf/reader/npz.yaml new file mode 100644 index 0000000000..b2340c62ce --- /dev/null +++ b/examples/minimal/datapipes/conf/reader/npz.yaml @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/minimal/datapipes/conf/reader/tensorstore_zarr.yaml b/examples/minimal/datapipes/conf/reader/tensorstore_zarr.yaml new file mode 100644 index 0000000000..487d1dc311 --- /dev/null +++ b/examples/minimal/datapipes/conf/reader/tensorstore_zarr.yaml @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TensorStore Zarr Reader Configuration +# High-performance async reader for Zarr files using TensorStore +_target_: physicsnemo.datapipes.core.readers.TensorStoreZarrReader +path: ??? +group_pattern: "*.zarr" +fields: null +default_values: null +cache_bytes_limit: 10000000 # 10 MB cache +data_copy_concurrency: 72 +file_io_concurrency: 72 +pin_memory: false +include_index_in_metadata: true + +# Coordinated subsampling configuration (optional) +# Reads contiguous chunks from large tensors to reduce IO bandwidth. +# Set to null to disable, or configure n_points and target_keys. +coordinated_subsampling: null +# Example: +# coordinated_subsampling: +# n_points: 10000 +# target_keys: +# - volume_coords +# - volume_fields diff --git a/examples/minimal/datapipes/conf/reader/zarr.yaml b/examples/minimal/datapipes/conf/reader/zarr.yaml new file mode 100644 index 0000000000..6ca4fb052c --- /dev/null +++ b/examples/minimal/datapipes/conf/reader/zarr.yaml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# _target_: physicsnemo.datapipes.core.readers.TensorStoreZarrReader +_target_: physicsnemo.datapipes.core.readers.ZarrReader +path: ??? +fields: null +default_values: null +group_pattern: "*.zarr" +pin_memory: false +include_index_in_metadata: true + diff --git a/examples/minimal/datapipes/conf/transforms/normalize.yaml b/examples/minimal/datapipes/conf/transforms/normalize.yaml new file mode 100644 index 0000000000..18a5931c38 --- /dev/null +++ b/examples/minimal/datapipes/conf/transforms/normalize.yaml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Normalize Transform Configuration +# Normalizes specified fields using mean-std or min-max scaling +# +# Method options: +# - mean_std: Applies (x - mean) / std +# - min_max: Applies (x - center) / half_range, normalizing to [-1, 1] +_target_: physicsnemo.datapipes.core.transforms.Normalize +_convert_: all +input_keys: + - features +method: mean_std # Required: "mean_std" or "min_max" +means: + features: 0.0 +stds: + features: 0.6 +eps: 1.0e-8 + +# For min_max method, use these instead of means/stds: +# mins: +# features: -1.0 +# maxs: +# features: 1.0 diff --git a/examples/minimal/datapipes/conf/transforms/subsample.yaml b/examples/minimal/datapipes/conf/transforms/subsample.yaml new file mode 100644 index 0000000000..4267bb0b19 --- /dev/null +++ b/examples/minimal/datapipes/conf/transforms/subsample.yaml @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SubsamplePoints Transform Configuration +# Subsamples point clouds to a fixed number of points +# Useful for handling variable-size point cloud data in batched training +_target_: physicsnemo.datapipes.core.transforms.SubsamplePoints +input_keys: + - coords + - features +n_points: 10000 +algorithm: uniform # Options: "uniform" or "poisson_fixed" +weights_key: null # Optional: key for weighted sampling (e.g., "surface_areas") diff --git a/examples/minimal/datapipes/conf/tutorial_04_pointcloud.yaml b/examples/minimal/datapipes/conf/tutorial_04_pointcloud.yaml new file mode 100644 index 0000000000..2696329def --- /dev/null +++ b/examples/minimal/datapipes/conf/tutorial_04_pointcloud.yaml @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ============================================================================= +# Tutorial 4: Point Cloud Pipeline Configuration (Modular) +# ============================================================================= +# +# This config demonstrates modular Hydra composition where: +# - Reader config is loaded from a separate file (conf/reader/) +# - Transform configs are loaded from separate files (conf/transforms/) +# - A single hydra.utils.instantiate() call builds the entire pipeline +# +# Run with: +# python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud +# +# Override from command line: +# python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ +# dataloader.batch_size=8 +# +# Switch reader (e.g., to NPZ): +# python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ +# reader=npz reader.path=./output/npz_data/ +# +# Override transform parameters: +# python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ +# subsample.n_points=5000 normalize.stds.features=1.0 +# +# ============================================================================= + +# ----------------------------------------------------------------------------- +# Hydra Defaults: Compose configs from modular files +# ----------------------------------------------------------------------------- +# Each default loads a config file into the specified key: +# - reader: zarr -> loads conf/reader/zarr.yaml into 'reader' key +# - transforms/subsample@subsample -> loads conf/transforms/subsample.yaml into 'subsample' key +# - transforms/normalize@normalize -> loads conf/transforms/normalize.yaml into 'normalize' key +# ----------------------------------------------------------------------------- +defaults: + - reader@reader: zarr + - transforms@subsample: subsample + - transforms@normalize: normalize + - _self_ + +# ----------------------------------------------------------------------------- +# Reader overrides: Customize the reader loaded from defaults +# ----------------------------------------------------------------------------- +reader: + path: ./output/pointcloud_data/ + +# ----------------------------------------------------------------------------- +# DataLoader: Top-level component that instantiates everything recursively +# ----------------------------------------------------------------------------- +dataloader: + _target_: physicsnemo.datapipes.core.DataLoader + batch_size: 4 + shuffle: true + drop_last: false + prefetch_factor: 2 + num_streams: 4 + use_streams: true + + # ------------------------------------------------------------------------- + # Dataset: Nested inside DataLoader + # ------------------------------------------------------------------------- + dataset: + _target_: physicsnemo.datapipes.core.Dataset + device: cuda + num_workers: 2 + + # Reader: Reference the config loaded from defaults + reader: ${reader} + + # Transforms: List referencing configs loaded from defaults + # Add or remove transforms by editing this list + transforms: + - ${subsample} + - ${normalize} + +# ----------------------------------------------------------------------------- +# Training settings +# ----------------------------------------------------------------------------- +training: + num_epochs: 2 + log_interval: 1 diff --git a/examples/minimal/datapipes/generate_regular_data.py b/examples/minimal/datapipes/generate_regular_data.py new file mode 100644 index 0000000000..87c27bb7c6 --- /dev/null +++ b/examples/minimal/datapipes/generate_regular_data.py 
@@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Script to generate synthetic data with configurable shapes and storage backends. + +Supports .npz, and zarr storage formats. +""" + +import argparse +import json +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np + + +def parse_shapes(shapes_str: str) -> Dict[str, Tuple[int, ...]]: + """ + Parse shape specification from command line. + + Expected format: "key1:dim1,dim2,dim3 key2:dim1,dim2" + Example: "velocity:100,64,64 pressure:100,32,32" + + Parameters + ---------- + shapes_str : str + Space-separated key:shape pairs where shape is comma-separated dimensions + + Returns + ------- + Dict[str, Tuple[int, ...]] + Dictionary mapping field names to shape tuples + """ + shapes = {} + for item in shapes_str.split(): + if ":" not in item: + raise ValueError( + f"Invalid shape specification: {item}. Expected format: key:dim1,dim2,..." + ) + + key, dims_str = item.split(":", 1) + try: + dims = tuple(int(d) for d in dims_str.split(",")) + except ValueError as e: + raise ValueError(f"Invalid dimensions for key '{key}': {dims_str}") from e + + shapes[key] = dims + + return shapes + + +def generate_synthetic_data( + num_samples: int, shapes: Dict[str, Tuple[int, ...]], seed: int = 42 +) -> Dict[str, np.ndarray]: + """ + Generate synthetic random data. + + Parameters + ---------- + num_samples : int + Number of samples to generate + shapes : Dict[str, Tuple[int, ...]] + Dictionary mapping field names to shape tuples (per sample) + seed : int, optional + Random seed for reproducibility + + Returns + ------- + Dict[str, np.ndarray] + Dictionary mapping field names to generated data arrays + Each array has shape (num_samples, *shape) + """ + rng = np.random.RandomState(seed) + data = {} + + for key, shape in shapes.items(): + full_shape = (num_samples,) + shape + # Generate random data in range [-1, 1] + data[key] = rng.uniform(-1.0, 1.0, size=full_shape).astype(np.float32) + print(f"Generated '{key}' with shape {full_shape}") + + return data + + +def save_npz(data: Dict[str, np.ndarray], output_dir: Path): + """ + Save data as separate .npz files per sample. + + Each sample is saved as an .npz file containing all fields. + + Parameters + ---------- + data : Dict[str, np.ndarray] + Dictionary of arrays to save, each with shape (num_samples, ...) 
+ output_dir : Path + Output directory + """ + output_dir.mkdir(parents=True, exist_ok=True) + + # Get number of samples (assumes all fields have the same number) + num_samples = next(iter(data.values())).shape[0] + + for i in range(num_samples): + # Extract this sample from all fields + sample_data = {key: array[i] for key, array in data.items()} + + # Save to individual file + filepath = output_dir / f"sample_{i:06d}.npz" + np.savez(filepath, **sample_data) + + total_size = sum(array.nbytes for array in data.values()) + print( + f"Saved {num_samples} samples as .npz files ({total_size / 1e6:.2f} MB total)" + ) + + +def save_zarr(data: Dict[str, np.ndarray], output_dir: Path): + """ + Save data as separate zarr directories per sample. + + Each sample is saved in its own zarr directory containing all fields. + + Parameters + ---------- + data : Dict[str, np.ndarray] + Dictionary of arrays to save, each with shape (num_samples, ...) + output_dir : Path + Output directory + """ + try: + import zarr + except ImportError: + raise ImportError( + "zarr is required for zarr storage backend. Install with: pip install zarr" + ) + + output_dir.mkdir(parents=True, exist_ok=True) + + # Get number of samples (assumes all fields have the same number) + num_samples = next(iter(data.values())).shape[0] + + for i in range(num_samples): + # Create zarr group for this sample + sample_dir = output_dir / f"sample_{i:06d}.zarr" + # zarr v3 compatible: pass path directly instead of DirectoryStore + root = zarr.open_group(str(sample_dir), mode="w") + + # Save all fields for this sample + for key, array in data.items(): + sample_data = array[i] + root.create_dataset( + key, + data=sample_data, + shape=sample_data.shape, + ) + + total_size = sum(array.nbytes for array in data.values()) + print( + f"Saved {num_samples} samples as zarr directories ({total_size / 1e6:.2f} MB total)" + ) + + +def save_metadata( + output_dir: Path, num_samples: int, shapes: Dict[str, Tuple[int, ...]], backend: str +): + """ + Save metadata about the generated dataset. 
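+
+    For example, running generate_regular_data.py with -n 100,
+    -s "velocity:64,64,64" and -b zarr writes a metadata.json equivalent to
+    (whitespace compressed here):
+
+        {"num_samples": 100,
+         "shapes": {"velocity": [64, 64, 64]},
+         "backend": "zarr"}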
+ + Parameters + ---------- + output_dir : Path + Output directory + num_samples : int + Number of samples generated + shapes : Dict[str, Tuple[int, ...]] + Shapes of the generated data + backend : str + Storage backend used + """ + metadata = { + "num_samples": num_samples, + "shapes": {k: list(v) for k, v in shapes.items()}, + "backend": backend, + } + + metadata_path = output_dir / "metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + print(f"Metadata saved to {metadata_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate synthetic data with configurable shapes and storage backends", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate 100 samples with two fields, save as .npz + python generate_data.py -n 100 -s "velocity:64,64,64 pressure:32,32,32" -b npz -o output/ + + # Generate 200 samples, save as zarr + python generate_data.py -n 200 -s "u:128,128 v:128,128" -b zarr -o zarr_data/ + """, + ) + + parser.add_argument( + "-n", + "--num-samples", + type=int, + required=True, + help="Number of samples to generate", + ) + + parser.add_argument( + "-s", + "--shapes", + type=str, + required=True, + help='Space-separated key:shape pairs (e.g., "field1:100,200 field2:64,64,64")', + ) + + parser.add_argument( + "-b", + "--backend", + type=str, + choices=["npz", "zarr"], + default="zarr", + help="Storage backend to use (default: zarr)", + ) + + parser.add_argument( + "-o", + "--output", + type=str, + default="synthetic_data", + help="Output directory (default: synthetic_data)", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility (default: 42)", + ) + + args = parser.parse_args() + + # Parse shapes + print("Parsing shape specifications...") + shapes = parse_shapes(args.shapes) + print(f"Shapes: {shapes}") + + # Generate data + print(f"\nGenerating {args.num_samples} samples...") + data = generate_synthetic_data(args.num_samples, shapes, seed=args.seed) + + # Save data + output_dir = Path(args.output) + print(f"\nSaving data to {output_dir} using backend '{args.backend}'...") + + if args.backend == "npz": + save_npz(data, output_dir) + elif args.backend == "zarr": + save_zarr(data, output_dir) + + # Save metadata + save_metadata(output_dir, args.num_samples, shapes, args.backend) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/examples/minimal/datapipes/generate_variable_points_data.py b/examples/minimal/datapipes/generate_variable_points_data.py new file mode 100644 index 0000000000..22648e61b0 --- /dev/null +++ b/examples/minimal/datapipes/generate_variable_points_data.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Script to generate synthetic data with variable number of points per sample. 
+ +Each sample has a semi-random number of points (between min_points and max_points), +while maintaining consistent feature dimensions across fields. + +Supports .npz, and zarr storage formats. +""" + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np + + +def parse_shapes(shapes_str: str) -> Dict[str, Tuple[int, ...]]: + """ + Parse shape specification from command line. + + Expected format: "key1:dim1,dim2 key2:dim1" + Example: "velocity:3 pressure:1 temperature:1" + + The shape specifies the feature dimensions (excluding the points dimension). + For example, "velocity:3" means each point has 3 velocity components. + + Parameters + ---------- + shapes_str : str + Space-separated key:shape pairs where shape is comma-separated feature dimensions + + Returns + ------- + Dict[str, Tuple[int, ...]] + Dictionary mapping field names to feature dimension tuples + """ + shapes = {} + for item in shapes_str.split(): + if ":" not in item: + raise ValueError( + f"Invalid shape specification: {item}. Expected format: key:dim1,dim2,..." + ) + + key, dims_str = item.split(":", 1) + try: + dims = tuple(int(d) for d in dims_str.split(",")) + except ValueError as e: + raise ValueError(f"Invalid dimensions for key '{key}': {dims_str}") from e + + shapes[key] = dims + + return shapes + + +def generate_point_counts( + num_samples: int, min_points: int, max_points: int, seed: int = 42 +) -> np.ndarray: + """ + Generate random point counts for each sample. + + Parameters + ---------- + num_samples : int + Number of samples to generate point counts for + min_points : int + Minimum number of points per sample + max_points : int + Maximum number of points per sample + seed : int, optional + Random seed for reproducibility + + Returns + ------- + np.ndarray + Array of point counts with shape (num_samples,) + """ + rng = np.random.RandomState(seed) + point_counts = rng.randint(min_points, max_points + 1, size=num_samples) + return point_counts + + +def generate_variable_sample( + num_points: int, shapes: Dict[str, Tuple[int, ...]], rng: np.random.RandomState +) -> Dict[str, np.ndarray]: + """ + Generate a single sample with variable number of points. + + Parameters + ---------- + num_points : int + Number of points for this sample + shapes : Dict[str, Tuple[int, ...]] + Dictionary mapping field names to feature dimension tuples + rng : np.random.RandomState + Random number generator + + Returns + ------- + Dict[str, np.ndarray] + Dictionary mapping field names to generated data arrays + Each array has shape (num_points, *feature_dims) + """ + sample_data = {} + + for key, feature_dims in shapes.items(): + full_shape = (num_points,) + feature_dims + # Generate random data in range [-1, 1] + sample_data[key] = rng.uniform(-1.0, 1.0, size=full_shape).astype(np.float32) + + return sample_data + + +def save_npz( + num_samples: int, + point_counts: np.ndarray, + shapes: Dict[str, Tuple[int, ...]], + output_dir: Path, + seed: int, +): + """ + Save data as separate .npz files per sample with variable points. + + Each sample is saved as an .npz file containing all fields. 
+ + Parameters + ---------- + num_samples : int + Number of samples to generate + point_counts : np.ndarray + Array of point counts for each sample + shapes : Dict[str, Tuple[int, ...]] + Dictionary of feature dimensions for each field + output_dir : Path + Output directory + seed : int + Random seed for reproducibility + """ + output_dir.mkdir(parents=True, exist_ok=True) + rng = np.random.RandomState(seed) + + total_size = 0 + for i in range(num_samples): + num_points = point_counts[i] + sample_data = generate_variable_sample(num_points, shapes, rng) + + # Save to individual file + filepath = output_dir / f"sample_{i:06d}.npz" + np.savez(filepath, **sample_data) + + # Track size + total_size += sum(array.nbytes for array in sample_data.values()) + + if (i + 1) % 10 == 0 or i == num_samples - 1: + print(f" Saved {i + 1}/{num_samples} samples...") + + print( + f"Saved {num_samples} samples as .npz files ({total_size / 1e6:.2f} MB total)" + ) + + +def save_zarr( + num_samples: int, + point_counts: np.ndarray, + shapes: Dict[str, Tuple[int, ...]], + output_dir: Path, + seed: int, +): + """ + Save data as separate zarr directories per sample with variable points. + + Each sample is saved in its own zarr directory containing all fields. + + Parameters + ---------- + num_samples : int + Number of samples to generate + point_counts : np.ndarray + Array of point counts for each sample + shapes : Dict[str, Tuple[int, ...]] + Dictionary of feature dimensions for each field + output_dir : Path + Output directory + seed : int + Random seed for reproducibility + """ + try: + import zarr + except ImportError: + raise ImportError( + "zarr is required for zarr storage backend. Install with: pip install zarr" + ) + + output_dir.mkdir(parents=True, exist_ok=True) + rng = np.random.RandomState(seed) + + total_size = 0 + for i in range(num_samples): + num_points = point_counts[i] + sample_data = generate_variable_sample(num_points, shapes, rng) + + # Create zarr group for this sample + sample_dir = output_dir / f"sample_{i:06d}.zarr" + root = zarr.open_group(str(sample_dir), mode="w") + + # Save all fields for this sample + for key, array in sample_data.items(): + root.create_dataset( + key, + data=array, + shape=array.shape, + ) + + # Track size + total_size += sum(array.nbytes for array in sample_data.values()) + + if (i + 1) % 10 == 0 or i == num_samples - 1: + print(f" Saved {i + 1}/{num_samples} samples...") + + print( + f"Saved {num_samples} samples as zarr directories ({total_size / 1e6:.2f} MB total)" + ) + + +def save_metadata( + output_dir: Path, + num_samples: int, + shapes: Dict[str, Tuple[int, ...]], + point_counts: np.ndarray, + min_points: int, + max_points: int, + backend: str, +): + """ + Save metadata about the generated dataset. 
+ + Parameters + ---------- + output_dir : Path + Output directory + num_samples : int + Number of samples generated + shapes : Dict[str, Tuple[int, ...]] + Feature dimensions of each field + point_counts : np.ndarray + Array of point counts for each sample + min_points : int + Minimum number of points + max_points : int + Maximum number of points + backend : str + Storage backend used + """ + metadata = { + "num_samples": num_samples, + "feature_shapes": {k: list(v) for k, v in shapes.items()}, + "point_counts": point_counts.tolist(), + "min_points": int(min_points), + "max_points": int(max_points), + "mean_points": float(np.mean(point_counts)), + "backend": backend, + "variable_points": True, + } + + metadata_path = output_dir / "metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + print(f"Metadata saved to {metadata_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate synthetic data with variable number of points per sample", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate 100 samples with variable points (30k-100k), 3D velocity and scalar pressure + python generate_variable_points_data.py -n 100 -s "velocity:3 pressure:1" -b npz -o output/ + + # Generate 50 samples with 10k-50k points + python generate_variable_points_data.py -n 50 -s "coords:3 features:8" --min-points 10000 --max-points 50000 -b zarr + + # Generate point cloud data with normals + python generate_variable_points_data.py -n 200 -s "xyz:3 normal:3 color:3" -b zarr -o pointcloud_data/ + """, + ) + + parser.add_argument( + "-n", + "--num-samples", + type=int, + required=True, + help="Number of samples to generate", + ) + + parser.add_argument( + "-s", + "--shapes", + type=str, + required=True, + help='Space-separated key:shape pairs for feature dimensions (e.g., "velocity:3 pressure:1")', + ) + + parser.add_argument( + "--min-points", + type=int, + default=30000, + help="Minimum number of points per sample (default: 30000)", + ) + + parser.add_argument( + "--max-points", + type=int, + default=100000, + help="Maximum number of points per sample (default: 100000)", + ) + + parser.add_argument( + "-b", + "--backend", + type=str, + choices=["npz", "zarr"], + default="npz", + help="Storage backend to use (default: npz)", + ) + + parser.add_argument( + "-o", + "--output", + type=str, + default="synthetic_variable_data", + help="Output directory (default: synthetic_variable_data)", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility (default: 42)", + ) + + args = parser.parse_args() + + # Validate point range + if args.min_points >= args.max_points: + raise ValueError( + f"min_points ({args.min_points}) must be less than max_points ({args.max_points})" + ) + + # Parse shapes + print("Parsing shape specifications...") + shapes = parse_shapes(args.shapes) + print(f"Feature shapes: {shapes}") + print(f"Point range: {args.min_points:,} to {args.max_points:,}") + + # Generate point counts for each sample + print(f"\nGenerating point counts for {args.num_samples} samples...") + point_counts = generate_point_counts( + args.num_samples, args.min_points, args.max_points, seed=args.seed + ) + print(f"Mean points per sample: {np.mean(point_counts):.0f}") + print(f"Total points across all samples: {np.sum(point_counts):,}") + + # Save data + output_dir = Path(args.output) + print(f"\nSaving data to {output_dir} using backend '{args.backend}'...") + + if args.backend == 
"npz": + save_npz(args.num_samples, point_counts, shapes, output_dir, args.seed) + elif args.backend == "zarr": + save_zarr(args.num_samples, point_counts, shapes, output_dir, args.seed) + + # Save metadata + save_metadata( + output_dir, + args.num_samples, + shapes, + point_counts, + args.min_points, + args.max_points, + args.backend, + ) + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/examples/minimal/datapipes/tutorial_01_getting_started.py b/examples/minimal/datapipes/tutorial_01_getting_started.py new file mode 100644 index 0000000000..2c64f78b1d --- /dev/null +++ b/examples/minimal/datapipes/tutorial_01_getting_started.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from pathlib import Path + +import torch + +# Import the core datapipe components +from physicsnemo.datapipes.core import ( + DataLoader, + Dataset, +) +from physicsnemo.datapipes.core.readers import ZarrReader + +""" +Tutorial 1: Getting Started with PhysicsNemo DataPipes +====================================================== + +This tutorial introduces the core concepts of PhysicsNemo's data loading +infrastructure. You'll learn how to: + +1. Create a Reader to load data from files +2. Understand the (TensorDict, metadata) return format +3. Wrap a reader in a Dataset +4. Iterate with a DataLoader +5. Access batch data via TensorDict keys + +Prerequisites +------------- +Before running this tutorial, generate some synthetic data: +""" + +# Generate 100 samples with velocity, pressure, and position fields +gen_cmd = 'python generate_regular_data.py -n 100 -s "velocity:128,128,128,3 pressure:128,128,128,1 position:128,128,128,3" -b zarr -o output/tutorial_data/' + +""" +This creates a directory structure like: + output/tutorial_data/ + ├── sample_000000.zarr/ + ├── sample_000001.zarr/ + ├── ... + └── metadata.json + +Run this tutorial: + python tutorial_01_getting_started.py + +Key Concepts +------------ +- **Reader**: Loads raw data from storage (HDF5, Zarr, NumPy, etc.) 
+- **TensorDict**: A dictionary-like container for named tensors +- **Dataset**: Combines a Reader with transforms and handles device transfer +- **DataLoader**: Batches samples and manages prefetching for efficiency +""" + + +def check_data_exists(data_path: str) -> bool: + """Check if tutorial data exists and provide helpful message if not.""" + path = Path(data_path) + if not path.exists(): + print(f"ERROR: Data not found at '{data_path}'") + print() + print("Please generate tutorial data first:") + print() + print(gen_cmd) + print() + return False + return True + + +# ============================================================================= +# Section 1: Creating a Reader +# ============================================================================= +def section_1_reader_basics(): + """ + Section 1: Creating Your First Reader + + Readers are the foundation of the datapipe system. They handle loading + data from various file formats and converting it to PyTorch tensors. + + PhysicsNemo provides several built-in readers: + - ZarrReader: For Zarr arrays (chunked, compressed storage) + - HDF5Reader: For HDF5 files + - NumpyReader: For .npy/.npz files + - VTKReader: For VTK mesh files + """ + print("=" * 70) + print("Section 1: Creating Your First Reader") + print("=" * 70) + print() + + # Path to our tutorial data + data_path = "./output/tutorial_data/" + + if not check_data_exists(data_path): + return None + + # Create a ZarrReader + # The reader automatically discovers all .zarr files in the directory + reader = ZarrReader( + path=data_path, + group_pattern="*.zarr", # Match files ending in .zarr + ) + + print(f"Created reader: {reader}") + print(f"Number of samples: {len(reader)}") + print(f"Field names: {reader.field_names}") + print() + + # Let's load a single sample directly from the reader + print("Loading sample 0 directly from reader...") + data, metadata = reader[0] + + print(f"Data type: {type(data)}") + print(f"Metadata: {metadata}") + print() + + # Examine the TensorDict contents + print("TensorDict contents:") + for key in data.keys(): + tensor = data[key] + print( + f" '{key}': shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}" + ) + + print() + return reader + + +# ============================================================================= +# Section 2: Understanding TensorDict +# ============================================================================= +def section_2_tensordict_basics(reader): + """ + Section 2: Understanding the (TensorDict, metadata) Format + + Every Reader returns a tuple of (TensorDict, metadata): + + - TensorDict: A dictionary-like container holding named tensors + - Access tensors by key: data["velocity"], data["pressure"] + - Supports batch operations, device transfers, and more + - From the tensordict library (PyTorch ecosystem) + + - metadata: A regular Python dict with non-tensor information + - Source file paths, sample indices, etc. 
+ - Useful for debugging and tracking data provenance + """ + print("=" * 70) + print("Section 2: Understanding TensorDict") + print("=" * 70) + print() + + if reader is None: + print("Skipping - reader not available") + return + + # Load a sample + data, metadata = reader[0] + + # TensorDict acts like a dictionary + print("Accessing data like a dictionary:") + print(f" data['velocity'].shape = {data['velocity'].shape}") + print(f" data['pressure'].shape = {data['pressure'].shape}") + print() + + # You can iterate over keys + print("Iterating over TensorDict:") + for key, value in data.items(): + print(f" {key}: {value.shape}") + print() + + # TensorDict supports device transfers + print("Device operations:") + print(f" Current device: {data.device}") + + if torch.cuda.is_available(): + data_gpu = data.to("cuda") + print(f" After .to('cuda'): {data_gpu.device}") + print(f" data_gpu['velocity'].device = {data_gpu['velocity'].device}") + else: + print(" (CUDA not available - skipping GPU transfer demo)") + print() + + # Metadata contains non-tensor information + print("Metadata contents:") + for key, value in metadata.items(): + print(f" '{key}': {value}") + print() + + +# ============================================================================= +# Section 3: Wrapping Reader in Dataset +# ============================================================================= +def section_3_dataset_basics(): + """ + Section 3: Wrapping a Reader in a Dataset + + The Dataset class wraps a Reader and adds: + - Transform pipeline support (covered in Tutorial 2) + - Automatic device transfer (move data to GPU) + - Prefetching capabilities for performance + + Dataset is the recommended way to access data for training. + """ + print("=" * 70) + print("Section 3: Wrapping Reader in Dataset") + print("=" * 70) + print() + + data_path = "./output/tutorial_data/" + if not check_data_exists(data_path): + return None + + # Create reader + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + + # Wrap in Dataset - simplest case, no transforms + dataset = Dataset(reader=reader) + + print(f"Dataset: {dataset}") + print(f"Length: {len(dataset)}") + print() + + # Access samples via indexing (same as reader, but through dataset) + print("Accessing samples through Dataset:") + data, metadata = dataset[0] + print(f" Sample 0 keys: {list(data.keys())}") + print() + + # Dataset supports automatic GPU transfer! + if torch.cuda.is_available(): + print("Creating Dataset with automatic GPU transfer:") + dataset_gpu = Dataset(reader=reader, device="cuda") + + data_gpu, _ = dataset_gpu[0] + print(f" Data device: {data_gpu.device}") + print(f" velocity device: {data_gpu['velocity'].device}") + + # Clean up + dataset_gpu.close() + else: + print("(CUDA not available - skipping GPU dataset demo)") + print() + + return dataset + + +# ============================================================================= +# Section 4: Using the DataLoader +# ============================================================================= +def section_4_dataloader_basics(dataset): + """ + Section 4: Iterating with a DataLoader + + The DataLoader provides batched iteration over a Dataset: + - Batches multiple samples together + - Supports shuffling + - Manages prefetching with CUDA streams for performance + - Compatible with PyTorch's DistributedSampler + + This is the typical interface for training loops. 
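+
+    For a sense of the shapes involved: with batch_size=4 and the tutorial
+    data's per-sample "velocity" field of shape (128, 128, 128, 3), each
+    batch's "velocity" tensor has shape (4, 128, 128, 128, 3), since the
+    batch dimension is prepended during collation.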
+ """ + print("=" * 70) + print("Section 4: Iterating with DataLoader") + print("=" * 70) + print() + + if dataset is None: + print("Skipping - dataset not available") + return + + # Create a DataLoader with batch_size=4 + dataloader = DataLoader( + dataset=dataset, + batch_size=4, + shuffle=True, # Shuffle samples each epoch + ) + + print(f"DataLoader batch_size: {dataloader.batch_size}") + print(f"Number of batches: {len(dataloader)}") + print() + + # Iterate over batches + print("Iterating over batches:") + for batch_idx, batch_data in enumerate(dataloader): + print(f"\nBatch {batch_idx}:") + print(f" Batch data type: {type(batch_data)}") + + for key in batch_data.keys(): + tensor = batch_data[key] + # Note: batch dimension is added as first dimension + print(f" '{key}': shape={tensor.shape}") + + # Just show first 2 batches for brevity + if batch_idx >= 1: + print("\n ... (showing only first 2 batches)") + break + + print() + + +# ============================================================================= +# Section 5: Putting It All Together +# ============================================================================= +def section_5_training_loop_example(): + """ + Section 5: A Simple Training Loop Example + + This section shows how datapipes fit into a typical training workflow. + We'll create a mock training loop that demonstrates: + - Loading batches of data + - Accessing specific fields for model input/output + - Basic timing for performance awareness + """ + print("=" * 70) + print("Section 5: Training Loop Example") + print("=" * 70) + print() + + data_path = "./output/tutorial_data/" + if not check_data_exists(data_path): + return + + # Setup: Reader -> Dataset -> DataLoader + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + + # For GPU training, specify device="cuda" + device = "cuda" if torch.cuda.is_available() else "cpu" + dataset = Dataset(reader=reader, device=device) + + dataloader = DataLoader( + dataset=dataset, + batch_size=4, + shuffle=True, + drop_last=True, # Drop incomplete final batch + ) + + print(f"Training on device: {device}") + print(f"Samples: {len(dataset)}, Batches per epoch: {len(dataloader)}") + print() + + # Mock training loop + print("Mock training loop (2 epochs):") + num_epochs = 2 + + for epoch in range(num_epochs): + epoch_start = time.time() + + for batch_idx, batch_data in enumerate(dataloader): + # In a real training loop, you would: + # 1. Extract inputs and targets + velocity = batch_data["velocity"] # Input features + pressure = batch_data["pressure"] # Target to predict + + # 2. Forward pass through model + # output = model(velocity) + + # 3. Compute loss + # loss = criterion(output, pressure) + + # 4. 
Backward pass and optimize + # loss.backward() + # optimizer.step() + + # For demo, just print shapes + if batch_idx == 0: + print( + f" Epoch {epoch}: velocity {velocity.shape}, " + f"pressure {pressure.shape}, device={velocity.device}" + ) + + epoch_time = time.time() - epoch_start + print(f" Epoch {epoch} completed in {epoch_time:.3f}s") + + print() + print("Training complete!") + print() + + # Clean up + dataset.close() + + +# ============================================================================= +# Main +# ============================================================================= +def main(): + """Run all tutorial sections.""" + print() + print("╔" + "═" * 68 + "╗") + print( + "║" + + " Tutorial 1: Getting Started with PhysicsNemo DataPipes ".center(68) + + "║" + ) + print("╚" + "═" * 68 + "╝") + print() + + # Section 1: Reader basics + reader = section_1_reader_basics() + + # Section 2: TensorDict format + section_2_tensordict_basics(reader) + + # Section 3: Dataset wrapper + dataset = section_3_dataset_basics() + + # Section 4: DataLoader iteration + section_4_dataloader_basics(dataset) + + # Section 5: Training loop example + section_5_training_loop_example() + + # Cleanup + if reader is not None: + reader.close() + if dataset is not None: + dataset.close() + + print("=" * 70) + print("Tutorial 1 Complete!") + print() + print("Key takeaways:") + print(" 1. Readers load raw data and return (TensorDict, metadata) tuples") + print(" 2. TensorDict is a dictionary-like container for named tensors") + print(" 3. Dataset wraps Reader + transforms + automatic device transfer") + print(" 4. DataLoader provides batched iteration for training") + print() + print("Next: Tutorial 2 - Transforms and Data Preprocessing") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/minimal/datapipes/tutorial_02_transforms.py b/examples/minimal/datapipes/tutorial_02_transforms.py new file mode 100644 index 0000000000..97cf8c1c58 --- /dev/null +++ b/examples/minimal/datapipes/tutorial_02_transforms.py @@ -0,0 +1,752 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import torch + +# Import core datapipe components +from physicsnemo.datapipes.core import DataLoader, Dataset +from physicsnemo.datapipes.core.readers import ZarrReader + +# Import transforms +from physicsnemo.datapipes.core.transforms import ( + Compose, + Normalize, + ReScale, + SubsamplePoints, + Translate, +) + +""" +Tutorial 2: Transforms and Data Preprocessing +============================================== + +This tutorial covers the transform system in PhysicsNemo DataPipes. +You'll learn how to: + +1. Apply a single transform (Normalize) +2. Compose multiple transforms together +3. Subsample point clouds with SubsamplePoints +4. 
Use geometric transforms (Translate, ReScale) +5. Save/load normalization statistics from files +6. Denormalize data with the inverse() method + +Prerequisites +------------- +Before running this tutorial, generate some synthetic data: + +""" +# For regular grid data (Sections 1-2, 5-6): +gen_cmd_regular = """python generate_regular_data.py -n 100 -s "velocity:128,128,128,3 pressure:128,128,128,1 position:128,128,128,3" -b zarr -o output/tutorial_data/""" + +# For point cloud data (Sections 3-4, 7): +gen_cmd_cloud = """python generate_variable_points_data.py -n 100 -s "coords:3 features:8" --min-points 50000 --max-points 100000 -b zarr -o output/pointcloud_data/""" + +""" +Run this tutorial: + python tutorial_02_transforms.py + +Key Concepts +------------ +- **Transform**: An operation that takes a TensorDict and returns a modified TensorDict +- **Compose**: Chains multiple transforms into a pipeline +- **input_keys**: Most transforms specify which fields to operate on +- **state_dict()**: Transforms can be serialized for reproducibility +""" + + +def check_data_exists(data_path: str, generation_command: str) -> bool: + """Check if tutorial data exists and provide helpful message if not.""" + path = Path(data_path) + if not path.exists(): + print(f"ERROR: Data not found at '{data_path}'") + print() + print("Please generate tutorial data first:") + print() + print(f" {generation_command}") + print() + return False + return True + + +# ============================================================================= +# Section 1: Single Transform - Normalize +# ============================================================================= +def section_1_single_transform(): + """ + Section 1: Applying a Single Transform (Normalize) + + The Normalize transform standardizes tensor values. It supports two methods: + - mean_std: (x - mean) / std + - min_max: scales to [-1, 1] range + + Transforms operate on TensorDict objects and return modified TensorDicts. 
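+    A minimal sketch of the call pattern (illustrative, with made-up
+    statistics; mirrors the code in this section):
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> normalize = Normalize(
+    ...     input_keys=["pressure"],
+    ...     method="mean_std",
+    ...     means={"pressure": 0.0},
+    ...     stds={"pressure": 0.58},
+    ... )
+    >>> data = TensorDict({"pressure": torch.randn(16, 16, 16, 1)})
+    >>> out = normalize(data)  # applies (x - mean) / std to "pressure"
+    >>> out["pressure"].shape
+    torch.Size([16, 16, 16, 1])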
+ """ + print("=" * 70) + print("Section 1: Single Transform - Normalize") + print("=" * 70) + print() + + data_path = "./output/tutorial_data/" + if not check_data_exists(data_path, gen_cmd_regular): + return None + + # Create reader and load a sample + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + data, metadata = reader[0] + + print("Before normalization:") + print( + f" velocity: mean={data['velocity'].mean():.4f}, std={data['velocity'].std():.4f}" + ) + print( + f" pressure: mean={data['pressure'].mean():.4f}, std={data['pressure'].std():.4f}" + ) + print() + + # Create a Normalize transform + # This will subtract mean and divide by std for specified keys + normalize = Normalize( + input_keys=["velocity", "pressure"], + method="mean_std", + # For real data, you'd compute these from your training set + means={"velocity": 0.0, "pressure": 0.0}, + stds={"velocity": 0.6, "pressure": 0.6}, # ~std of uniform[-1,1] + ) + + print(f"Transform: {normalize}") + print() + + # Apply the transform + normalized_data = normalize(data) + + print("After normalization:") + print( + f" velocity: mean={normalized_data['velocity'].mean():.4f}, std={normalized_data['velocity'].std():.4f}" + ) + print( + f" pressure: mean={normalized_data['pressure'].mean():.4f}, std={normalized_data['pressure'].std():.4f}" + ) + print() + + # Demonstrate min-max normalization + print("Min-Max normalization example:") + normalize_minmax = Normalize( + input_keys=["velocity"], + method="min_max", + mins={"velocity": -1.0}, + maxs={"velocity": 1.0}, + ) + + minmax_data = normalize_minmax(data) + print( + f" velocity range: [{minmax_data['velocity'].min():.4f}, {minmax_data['velocity'].max():.4f}]" + ) + print() + + reader.close() + return normalize + + +# ============================================================================= +# Section 2: Composing Multiple Transforms +# ============================================================================= +def section_2_compose_transforms(): + """ + Section 2: Composing Multiple Transforms + + The Compose class chains multiple transforms together, applying them + in sequence. This is similar to torchvision.transforms.Compose. + + Transform pipelines are the recommended way to build preprocessing. 
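+    A minimal sketch of the pattern (illustrative; the statistics are
+    placeholders, not values computed from real training data):
+
+    >>> pipeline = Compose([
+    ...     Normalize(input_keys=["velocity"], method="mean_std",
+    ...               means={"velocity": 0.0}, stds={"velocity": 0.6}),
+    ...     Normalize(input_keys=["pressure"], method="mean_std",
+    ...               means={"pressure": 0.0}, stds={"pressure": 0.6}),
+    ... ])
+    >>> transformed = pipeline(data)  # transforms are applied in order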
+ """ + print("=" * 70) + print("Section 2: Composing Multiple Transforms") + print("=" * 70) + print() + + data_path = "./output/tutorial_data/" + if not check_data_exists(data_path, gen_cmd_regular): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + + # Create multiple transforms + normalize_velocity = Normalize( + input_keys=["velocity"], + method="mean_std", + means={"velocity": 0.0}, + stds={"velocity": 0.6}, + ) + + normalize_pressure = Normalize( + input_keys=["pressure"], + method="mean_std", + means={"pressure": 0.0}, + stds={"pressure": 0.6}, + ) + + # Compose them into a pipeline + transform_pipeline = Compose( + [ + normalize_velocity, + normalize_pressure, + ] + ) + + print(f"Transform pipeline:\n{transform_pipeline}") + print() + + # Apply pipeline to data + data, _ = reader[0] + + print("Before pipeline:") + print(f" velocity std: {data['velocity'].std():.4f}") + print(f" pressure std: {data['pressure'].std():.4f}") + + transformed_data = transform_pipeline(data) + + print("After pipeline:") + print(f" velocity std: {transformed_data['velocity'].std():.4f}") + print(f" pressure std: {transformed_data['pressure'].std():.4f}") + print() + + # Better approach: Use transforms directly with Dataset + print("Using transforms with Dataset (recommended approach):") + + dataset = Dataset( + reader=reader, + transforms=[normalize_velocity, normalize_pressure], + ) + + data, _ = dataset[0] + print(f" velocity std: {data['velocity'].std():.4f}") + print(f" pressure std: {data['pressure'].std():.4f}") + print() + + dataset.close() + return transform_pipeline + + +# ============================================================================= +# Section 3: Point Cloud Subsampling +# ============================================================================= +def section_3_subsampling(): + """ + Section 3: Point Cloud Subsampling with SubsamplePoints + + Scientific data often involves large point clouds (meshes, particles). + SubsamplePoints efficiently downsamples while maintaining correspondence + between related fields (coordinates, features, normals, etc.). + + Supports: + - Uniform random sampling + - Poisson disk sampling (for very large datasets) + - Weighted sampling (e.g., area-weighted for surfaces) + """ + print("=" * 70) + print("Section 3: Point Cloud Subsampling") + print("=" * 70) + print() + + data_path = "./output/pointcloud_data/" + if not check_data_exists(data_path, gen_cmd_cloud): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + + # Load a sample to see its original size + data, metadata = reader[0] + + print("Original point cloud:") + print(f" coords shape: {data['coords'].shape}") + print(f" features shape: {data['features'].shape}") + print() + + # Create a SubsamplePoints transform + # This samples the same indices from both coords and features + subsample = SubsamplePoints( + input_keys=["coords", "features"], # Keys to subsample together + n_points=10000, # Target number of points + algorithm="uniform", # or "poisson_fixed" for very large data + ) + # Note: the subsampling will assume a consistent leading dimension for all + # its input keys: so, it will generate an index of shape [n_points] and slice + # all input_keys in the same way. 
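+    # Illustrative sketch of that behavior (not the library's implementation):
+    # conceptually, uniform subsampling draws one shared index and applies it
+    # to every input key, e.g.
+    #
+    #   idx = torch.randperm(data["coords"].shape[0])[:10000]
+    #   coords_sub, features_sub = data["coords"][idx], data["features"][idx]
+    #
+    # which is what keeps coords and features in point-wise correspondence.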
+ + print(f"Transform: {subsample}") + print() + + # Apply subsampling + subsampled_data = subsample(data) + + print("After subsampling:") + print(f" coords shape: {subsampled_data['coords'].shape}") + print(f" features shape: {subsampled_data['features'].shape}") + print() + + # Use with Dataset for full pipeline + print("Using SubsamplePoints in a Dataset:") + + dataset = Dataset( + reader=reader, + transforms=[subsample], + ) + + # Iterate over a few samples + for i in range(3): + data, _ = dataset[i] + print( + f" Sample {i}: coords {data['coords'].shape}, features {data['features'].shape}" + ) + + print() + dataset.close() + return subsample + + +# ============================================================================= +# Section 4: Geometric Transforms +# ============================================================================= +def section_4_geometric_transforms(): + """ + Section 4: Geometric Transforms (Translate, ReScale) + + PhysicsNemo provides geometric transforms useful for point clouds and meshes: + - Translate: Shift coordinates by a fixed offset + - ReScale: Scale coordinates by a factor + + These are commonly used for data augmentation or centering data. + """ + print("=" * 70) + print("Section 4: Geometric Transforms") + print("=" * 70) + print() + + data_path = "./output/pointcloud_data/" + if not check_data_exists(data_path, gen_cmd_cloud): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + data, _ = reader[0] + + # Original statistics + print("Original coordinates:") + coords = data["coords"] + print( + f" Mean: [{coords[:, 0].mean():.4f}, {coords[:, 1].mean():.4f}, {coords[:, 2].mean():.4f}]" + ) + print(f" Min: [{coords.min():.4f}]") + print(f" Max: [{coords.max():.4f}]") + print() + + # Translate: shift coordinates by subtracting a center point + # center_key_or_value can be a tensor or a key name referencing a tensor in the data + translate = Translate( + input_keys=["coords"], + center_key_or_value=torch.tensor( + [-0.5, -0.5, -0.5] + ), # Subtract this (shifts by +0.5) + ) + + translated_data = translate(data) + t_coords = translated_data["coords"] + + print("After Translate([0.5, 0.5, 0.5]):") + print( + f" Mean: [{t_coords[:, 0].mean():.4f}, {t_coords[:, 1].mean():.4f}, {t_coords[:, 2].mean():.4f}]" + ) + print() + + # ReScale: scale coordinates by dividing by a reference scale + # To scale UP by 2x, divide by 0.5 + rescale = ReScale( + input_keys=["coords"], + reference_scale=torch.tensor([0.5, 0.5, 0.5]), # Divide by this (scales by 2x) + ) + + rescaled_data = rescale(data) + r_coords = rescaled_data["coords"] + + print("After ReScale(2.0):") + print(f" Min: [{r_coords.min():.4f}]") + print(f" Max: [{r_coords.max():.4f}]") + print() + + # Compose geometric transforms with other transforms + print("Complete preprocessing pipeline:") + + # First subsample, then center (translate), then scale + pipeline = Compose( + [ + SubsamplePoints(input_keys=["coords", "features"], n_points=5000), + Translate( + input_keys=["coords"], + center_key_or_value=torch.tensor( + [0.0, 0.0, 0.0] + ), # Subtract origin (no-op here) + ), + ReScale( + input_keys=["coords"], + reference_scale=torch.tensor([0.5, 0.5, 0.5]), # Scale up by 2x + ), + ] + ) + + processed_data = pipeline(data) + print(f" Final coords shape: {processed_data['coords'].shape}") + print(f" Final features shape: {processed_data['features'].shape}") + print() + + reader.close() + + +# ============================================================================= +# Section 5: 
Saving and Loading Normalization Statistics +# ============================================================================= +def section_5_stats_serialization(): + """ + Section 5: Saving/Loading Normalization Statistics + + For reproducibility, you can save normalization statistics to files + and load them later. This is essential for: + - Using the same normalization at training and inference time + - Sharing preprocessing configs across experiments + """ + print("=" * 70) + print("Section 5: Saving/Loading Normalization Statistics") + print("=" * 70) + print() + + data_path = "./output/tutorial_data/" + if not check_data_exists(data_path, gen_cmd_regular): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + + # In practice, compute statistics from your training data + print("Step 1: Compute statistics from training data") + print(" (In practice, iterate over all samples to compute mean/std)") + + # For demo, we'll use known values for uniform[-1,1] data + velocity_mean = 0.0 + velocity_std = 0.58 # Approximately sqrt(1/3) for uniform[-1,1] + pressure_mean = 0.0 + pressure_std = 0.58 + + print(f" velocity: mean={velocity_mean}, std={velocity_std}") + print(f" pressure: mean={pressure_mean}, std={pressure_std}") + print() + + # Create a temporary directory for saving stats + with TemporaryDirectory() as tmpdir: + stats_file = Path(tmpdir) / "normalization_stats.npz" + + # Step 2: Save statistics to .npz file + print(f"Step 2: Save statistics to {stats_file.name}") + + # The file format expected by Normalize.load_stats_from_npz: + # Each field maps to a dict with 'mean', 'std', 'min', 'max' keys + np.savez( + stats_file, + velocity={"mean": np.array(velocity_mean), "std": np.array(velocity_std)}, + pressure={"mean": np.array(pressure_mean), "std": np.array(pressure_std)}, + ) + print(" Stats saved!") + print() + + # Step 3: Load statistics when creating transform + print("Step 3: Create Normalize transform from stats file") + + normalize = Normalize( + input_keys=["velocity", "pressure"], + method="mean_std", + stats_file=str(stats_file), + ) + + print(f" Loaded transform: {normalize}") + print() + + # Verify it works + data, _ = reader[0] + normalized = normalize(data) + print("Step 4: Verify normalization") + print(f" velocity std after normalization: {normalized['velocity'].std():.4f}") + print(f" pressure std after normalization: {normalized['pressure'].std():.4f}") + print() + + # Alternative: Use state_dict() for serialization + print("Alternative: Using state_dict() for serialization") + + normalize = Normalize( + input_keys=["velocity"], + method="mean_std", + means={"velocity": 0.0}, + stds={"velocity": 0.58}, + ) + + state = normalize.state_dict() + print(f" state_dict keys: {list(state.keys())}") + print() + + # Create a new transform and load the state + new_normalize = Normalize( + input_keys=["velocity"], + method="mean_std", + means={"velocity": 999.0}, # Placeholder + stds={"velocity": 999.0}, + ) + new_normalize.load_state_dict(state) + + print(" Loaded state into new transform ✓") + print() + + reader.close() + + +# ============================================================================= +# Section 6: Denormalization with inverse() +# ============================================================================= +def section_6_inverse_normalization(): + """ + Section 6: Denormalization with the inverse() Method + + After your model makes predictions, you often need to convert back + to physical units. 
The Normalize transform provides an inverse() + method for this purpose. + """ + print("=" * 70) + print("Section 6: Denormalization with inverse()") + print("=" * 70) + print() + + data_path = "./output/tutorial_data/" + if not check_data_exists(data_path, gen_cmd_regular): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + data, _ = reader[0] + + print("Original data statistics:") + print(f" pressure mean: {data['pressure'].mean():.4f}") + print(f" pressure std: {data['pressure'].std():.4f}") + print(f" pressure min: {data['pressure'].min():.4f}") + print(f" pressure max: {data['pressure'].max():.4f}") + print() + + # Create normalizer + normalize = Normalize( + input_keys=["pressure"], + method="mean_std", + means={"pressure": 0.0}, + stds={"pressure": 0.58}, + ) + + # Forward: normalize + normalized_data = normalize(data) + print("After normalization:") + print(f" pressure mean: {normalized_data['pressure'].mean():.4f}") + print(f" pressure std: {normalized_data['pressure'].std():.4f}") + print() + + # Inverse: denormalize + denormalized_data = normalize.inverse(normalized_data) + print("After denormalization (inverse):") + print(f" pressure mean: {denormalized_data['pressure'].mean():.4f}") + print(f" pressure std: {denormalized_data['pressure'].std():.4f}") + print() + + # Verify round-trip accuracy + original_pressure = data["pressure"] + roundtrip_pressure = denormalized_data["pressure"] + max_error = (original_pressure - roundtrip_pressure).abs().max() + + print("Round-trip verification:") + print(f" Max absolute error: {max_error:.2e}") + print(f" Round-trip accurate: {'✓' if max_error < 1e-5 else '✗'}") + print() + + # Practical example: Model prediction pipeline + print("Practical example: Model prediction pipeline") + print(" 1. Load data, feed to model") + print(" 2. Model outputs normalized prediction") + print(" 2. Normalize targets → compute model loss") + print(" 4. Denormalize output → get physical values and metrics") + print() + + reader.close() + + +# ============================================================================= +# Section 7: Complete Pipeline Example +# ============================================================================= +def section_7_complete_pipeline(): + """ + Section 7: Complete Preprocessing Pipeline + + This section demonstrates a realistic preprocessing pipeline combining + multiple transforms in a training-ready configuration. + """ + print("=" * 70) + print("Section 7: Complete Preprocessing Pipeline") + print("=" * 70) + print() + + data_path = "./output/pointcloud_data/" + if not check_data_exists(data_path, gen_cmd_cloud): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + + print("Building a complete preprocessing pipeline:") + print() + print(" Pipeline steps:") + print(" 1. SubsamplePoints: Reduce to 10,000 points") + print(" 2. Translate: Center at origin") + print(" 3. ReScale: Normalize spatial extent") + print(" 4. 
Normalize: Standardize feature values") + print() + + # Define transforms + transforms = [ + # Step 1: Subsample to manageable size + SubsamplePoints( + input_keys=["coords", "features"], + n_points=10000, + algorithm="uniform", + ), + # Step 2: Translate (center at origin) + Translate( + input_keys=["coords"], + center_key_or_value=torch.tensor([0.0, 0.0, 0.0]), + ), + # Step 3: Scale coordinates (divide by 0.5 = multiply by 2) + ReScale( + input_keys=["coords"], + reference_scale=torch.tensor([0.5, 0.5, 0.5]), + ), + # Step 4: Normalize features + Normalize( + input_keys=["features"], + method="mean_std", + means={"features": 0.0}, + stds={"features": 0.58}, + ), + ] + + # Create dataset with the full pipeline + device = "cuda" if torch.cuda.is_available() else "cpu" + dataset = Dataset( + reader=reader, + transforms=transforms, + device=device, + ) + + print(f"Dataset created on device: {device}") + print(f"Number of samples: {len(dataset)}") + print() + + # Create DataLoader + dataloader = DataLoader( + dataset=dataset, + batch_size=4, + shuffle=True, + ) + + print("Sample batch from DataLoader:") + batch_data = next(iter(dataloader)) + + for key in batch_data.keys(): + tensor = batch_data[key] + print(f" '{key}': shape={tensor.shape}, device={tensor.device}") + print(f" mean={tensor.mean():.4f}, std={tensor.std():.4f}") + + print() + + # Timing comparison + print("Performance comparison:") + + # Time several iterations + start = time.time() + for i, batch_data in enumerate(dataloader): + if i >= 4: + break + # Simulate some computation + _ = batch_data["features"].sum() + elapsed = time.time() - start + + print(f" 5 batches loaded and processed in {elapsed:.3f}s") + print(f" Average time per batch: {elapsed / 5 * 1000:.1f}ms") + print() + + dataset.close() + print("Pipeline complete!") + print() + + +# ============================================================================= +# Main +# ============================================================================= +def main(): + """Run all tutorial sections.""" + print() + print("╔" + "═" * 68 + "╗") + print("║" + " Tutorial 2: Transforms and Data Preprocessing ".center(68) + "║") + print("╚" + "═" * 68 + "╝") + print() + + # Section 1: Single transform + section_1_single_transform() + + # Section 2: Compose transforms + section_2_compose_transforms() + + # Section 3: Point cloud subsampling + section_3_subsampling() + + # Section 4: Geometric transforms + section_4_geometric_transforms() + + # Section 5: Stats serialization + section_5_stats_serialization() + + # Section 6: Inverse normalization + section_6_inverse_normalization() + + # Section 7: Complete pipeline + section_7_complete_pipeline() + + print("=" * 70) + print("Tutorial 2 Complete!") + print() + print("Key takeaways:") + print(" 1. Transforms operate on TensorDict and return modified TensorDict") + print(" 2. Compose chains multiple transforms into a pipeline") + print(" 3. SubsamplePoints maintains correspondence between related fields") + print(" 4. Geometric transforms (Translate, ReScale) help with data prep") + print(" 5. Save/load normalization stats for reproducibility") + print(" 6. 
Use inverse() to convert predictions back to physical units") + print() + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/minimal/datapipes/tutorial_03_custom_gnn_datapipe.py b/examples/minimal/datapipes/tutorial_03_custom_gnn_datapipe.py new file mode 100644 index 0000000000..e203945618 --- /dev/null +++ b/examples/minimal/datapipes/tutorial_03_custom_gnn_datapipe.py @@ -0,0 +1,469 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 3: Custom Collation for Graph Neural Networks +======================================================= + +This tutorial demonstrates how to build a GNN-ready data pipeline using +PhysicsNeMo DataPipes. You'll learn how to: + +1. Build a transform that computes KNN graph edges +2. Use PyTorch Geometric's built-in batching via `Batch.from_data_list` +3. Put it all together in a GNN-ready pipeline + +Prerequisites +------------- +Before running this tutorial, generate point cloud data: + + python generate_variable_points_data.py -n 100 -s "coords:3 features:8" --min-points 50000 --max-points 100000 -b zarr -o output/pointcloud_data/ + +Run this tutorial: + python tutorial_03_custom_gnn_datapipe.py + +Key Concepts +------------ +- **Custom Transform**: Subclass `Transform` and implement `__call__()` +- **PyG Collator**: Use `torch_geometric.data.Batch.from_data_list()` for easy batching +- **PyG Batching**: Automatic edge index offsetting, feature concatenation, and batch tensor + +GNN Batching Background +----------------------- +Graph Neural Networks require special batching because graphs have variable +numbers of nodes and edges. 
PyTorch Geometric uses a "disjoint graph" approach: +- Concatenate all node features into one large tensor +- Offset edge indices so each graph's edges point to the correct nodes +- Add a `batch` tensor indicating which graph each node belongs to +""" + +import time +from pathlib import Path +from typing import Any, Sequence + +import torch +from torch_geometric.data import Batch as PyGBatch +from torch_geometric.data import Data as PyGData + +# Import core datapipe components +from physicsnemo.datapipes.core import DataLoader, Dataset +from physicsnemo.datapipes.core.collate import Collator +from physicsnemo.datapipes.core.readers import ZarrReader +from physicsnemo.datapipes.core.transforms import ( + KNNNeighbors, + SubsamplePoints, +) + + +def check_data_exists(data_path, gen_cmd): + """Check if tutorial data exists and provide helpful message if not.""" + path = Path(data_path) + if not path.exists(): + print(f"ERROR: Data not found at '{data_path}'") + print() + print("Please generate tutorial data first:") + print() + print(f" {gen_cmd}") + print() + return False + return True + + +# ============================================================================= +# Section 2: PyG-Style Graph Collator +# ============================================================================= + + +class PyGCollator(Collator): + """ + Collator that batches graphs using PyTorch Geometric's built-in batching. + + This collator converts each sample to a PyG Data object, then uses + `Batch.from_data_list()` to handle all the complexity of graph batching: + - Node features are concatenated: (N1 + N2 + ... + Nb, F) + - Edge indices are automatically offset and concatenated + - A `batch` tensor tracks which nodes belong to which graph + + Example: + Graph 0: 100 nodes, edges [[0,1,2], [1,2,0]] + Graph 1: 150 nodes, edges [[0,1], [1,0]] + + Batched (handled automatically by PyG): + - nodes: (250, F) + - edge_index: [[0,1,2,100,101], [1,2,0,101,100]] # Graph 1 offset by 100 + - batch: [0]*100 + [1]*150 + """ + + def __init__( + self, + edge_index_key: str = "edge_index", + collate_metadata: bool = False, + ) -> None: + """ + Initialize the PyG-style collator. + + Args: + edge_index_key: Key for edge indices in the input data. + Expected shape is [num_nodes, k] from KNN, which will be + converted to PyG's [2, num_edges] format. + """ + self.collate_metadata = collate_metadata + self.edge_index_key = edge_index_key + + @staticmethod + def knn_to_edge_index(knn_indices: torch.Tensor) -> torch.Tensor: + """ + Convert KNN indices to PyG edge_index format. + + Args: + knn_indices: Tensor of shape [num_nodes, k] where each row contains + the k nearest neighbor indices for that node. + + Returns: + edge_index: Tensor of shape [2, num_nodes * k] in PyG COO format, + where edge_index[0] is source nodes and edge_index[1] is target nodes. + """ + num_nodes, k = knn_indices.shape + # Source nodes: each node index repeated k times + source = torch.arange(num_nodes, device=knn_indices.device).repeat_interleave(k) + # Target nodes: flatten the KNN indices + target = knn_indices.reshape(-1) + return torch.stack([source, target], dim=0) + + def __call__( + self, samples: Sequence[tuple[dict, dict[str, Any]]] + ) -> tuple[PyGBatch, list[dict[str, Any]]]: + """ + Collate graphs into a batched PyG Batch object. + + Args: + samples: Sequence of (TensorDict/dict, metadata) tuples. + + Returns: + Tuple of (PyG Batch, list of metadata dicts). 
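+        Example (illustrative; shows the KNN-to-COO conversion applied to each
+        sample internally via ``knn_to_edge_index``):
+
+        >>> knn = torch.tensor([[1, 2], [0, 2], [0, 1]])  # 3 nodes, k=2
+        >>> PyGCollator.knn_to_edge_index(knn)
+        tensor([[0, 0, 1, 1, 2, 2],
+                [1, 2, 0, 2, 0, 1]])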
+ """ + if not samples: + raise ValueError("Cannot collate empty sequence of samples") + + # Separate data and metadata + data_list = [data for data, _ in samples] + + # Convert each sample to a PyG Data object + pyg_data_list = [] + for data in data_list: + # Build kwargs for PyG Data, renaming edge_index_key to 'edge_index' + data_kwargs = {} + for key in data.keys(): + tensor = data[key] + if key == self.edge_index_key: + # Convert from KNN format [num_nodes, k] to PyG format [2, num_edges] + data_kwargs["edge_index"] = self.knn_to_edge_index(tensor) + else: + data_kwargs[key] = tensor + + pyg_data_list.append(PyGData(**data_kwargs)) + + # Use PyG's built-in batching - handles edge index offsetting automatically + batched_data = PyGBatch.from_data_list(pyg_data_list) + + if self.collate_metadata: + metadata_list = [meta for _, meta in samples] + return batched_data, list(metadata_list) + else: + return batched_data + + def __repr__(self) -> str: + return f"PyGCollator(edge_index_key={self.edge_index_key})" + + +data_path = "./output/pointcloud_data/" +gen_cmd = 'python generate_variable_points_data.py -n 100 -s "coords:3 features:8" --min-points 50000 --max-points 100000 -b zarr -o output/pointcloud_data/' +# ============================================================================= +# Section 3: Demonstration +# ============================================================================= + + +def section_1_knn_transform(): + """ + Section 1: Computing KNN Graph Edges + + Shows how to use the ComputeKNNEdges transform to build + graph structure from point cloud positions. + """ + print("=" * 70) + print("Section 1: Computing KNN Graph Edges") + print("=" * 70) + print() + + if not check_data_exists(data_path, gen_cmd): + return None + + # Load a sample using ZarrReader + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + data, metadata = reader[0] + + print(f"Loaded sample with {data['coords'].shape[0]} points") + print(f"Fields: {list(data.keys())}") + print() + + # Create and apply the KNN edge transform + knn_transform = KNNNeighbors( + points_key="coords", + queries_key="coords", # Apply the kNN to itself. + k=8, + extract_keys=["features"], + ) + print(f"Transform: {knn_transform}") + print() + + data_with_edges = knn_transform(data) + + edge_index = "neighbors_indices" + + print("After transform:") + print(f" Fields: {list(data_with_edges.keys())}") + print(f" edge_index shape: {data_with_edges[edge_index].shape}") + print() + + # Verify graph structure + n_nodes = data_with_edges["coords"].shape[0] + n_edges = data_with_edges[edge_index].shape[1] + + print(f"Graph structure:") + print(f" Nodes: {n_nodes}") + print(f" Edges / node: {n_edges}") + print() + + reader.close() + return knn_transform + + +def section_2_pyg_collator(): + """ + Section 2: PyG-Style Graph Batching + + Demonstrates how the PyGCollator uses PyG's Batch.from_data_list() + to combine multiple graphs into a single batched graph. + """ + print("=" * 70) + print("Section 2: PyG-Style Graph Collator") + print("=" * 70) + print() + + if not check_data_exists(data_path, gen_cmd): + return None + + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + knn_transform = KNNNeighbors( + points_key="coords", + queries_key="coords", # Apply the kNN to itself. 
+ k=8, + extract_keys=["features"], + ) + + # Load and transform a few samples + print("Loading 3 individual graphs:") + samples = [] + for i in range(3): + data, meta = reader[i] + data = knn_transform(data) + samples.append((data, meta)) + n_nodes = data["coords"].shape[0] + n_edges = data["neighbors_indices"].shape[1] + print(f" Graph {i}: {n_nodes} nodes, {n_edges} edges") + + print() + + # Apply collator - uses PyG's Batch.from_data_list() internally + collator = PyGCollator(edge_index_key="neighbors_indices", collate_metadata=True) + print(f"Collator: {collator}") + print() + + batched_data, batch_metadata = collator(samples) + + print(f"Batched graph (type: {type(batched_data).__name__}):") + for key in batched_data.keys(): + tensor = batched_data[key] + print(f" {key}: shape={tensor.shape}") + print(f"Batch metadata: {batch_metadata}") + + print() + print("Batch tensor distribution (nodes per graph):") + batch = batched_data.batch + for i in range(3): + count = (batch == i).sum().item() + print(f" Graph {i}: {count} nodes") + print() + + reader.close() + + +def section_3_complete_pipeline(): + """ + Section 3: Complete GNN Data Pipeline + + Puts everything together: reader, transforms, collator, + and DataLoader for a complete GNN training pipeline. + """ + print("=" * 70) + print("Section 3: Complete GNN Data Pipeline") + print("=" * 70) + print() + + if not check_data_exists(data_path, gen_cmd): + return + + # 1. Create reader + print("Step 1: Create reader") + reader = ZarrReader(path=data_path, group_pattern="*.zarr") + print(f" Reader: {len(reader)} samples") + print() + + # 2. Define transforms + print("Step 2: Define transforms") + transforms = [ + # Subsample to fixed size for consistent batching + SubsamplePoints( + input_keys=["coords", "features"], + n_points=500, + algorithm="uniform", + ), + # Compute graph edges + KNNNeighbors( + points_key="coords", + queries_key="coords", + k=8, + extract_keys=["features"], + ), + ] + print(f" Transforms: {[type(t).__name__ for t in transforms]}") + print() + + # 3. Create dataset + print("Step 3: Create dataset") + device = "cuda" if torch.cuda.is_available() else "cpu" + dataset = Dataset( + reader=reader, + transforms=transforms, + device=device, + ) + print(f" Dataset: {len(dataset)} samples on {device}") + print() + + # 4. Create dataloader with PyG collator + print("Step 4: Create DataLoader with PyG collator") + collator = PyGCollator(edge_index_key="neighbors_indices") + dataloader = DataLoader( + dataset=dataset, + batch_size=4, + shuffle=True, + collate_fn=collator, + collate_metadata=False, + ) + print(f" DataLoader: batch_size=4, {len(dataloader)} batches") + print() + + # 5. 
Iterate over batches + print("Step 5: Iterate over batches") + print("-" * 50) + + start = time.time() + for batch_idx, batch_data in enumerate(dataloader): + elapsed = time.time() - start + + print(f"\nBatch {batch_idx} (loaded in {elapsed:.3f}s):") + print(f" Batch type: {type(batch_data)}") + print(f" Total nodes: {batch_data.coords.shape[0]}") + print(f" Total edges: {batch_data.edge_index.shape[1]}") + + # Show per-graph breakdown + batch = batch_data.batch + n_graphs = batch.max().item() + 1 + print(f" Graphs in batch: {n_graphs}") + for i in range(n_graphs): + n_nodes = (batch == i).sum().item() + print(f" Graph {i}: {n_nodes} nodes") + + # Show data shapes for GNN input + print(f" Data shapes:") + print(f" coords: {batch_data.coords.shape}") + print(f" features: {batch_data.features.shape}") + print(f" edge_index: {batch_data.edge_index.shape}") + print(f" batch: {batch_data.batch.shape}") + + if batch_idx >= 1: + print("\n ... (showing first 2 batches)") + break + + start = time.time() + + print() + print("-" * 50) + print() + + # 6. Show PyG integration + print("Step 6: PyTorch Geometric Integration") + print("-" * 50) + print() + print("The batch is already a PyG Batch object! Use it directly with PyG models:") + print() + print(" for pyg_batch, _ in dataloader:") + print(" # pyg_batch is already a torch_geometric.data.Batch") + print(" # Access attributes directly:") + print(" # pyg_batch.coords, pyg_batch.features, pyg_batch.edge_index") + print(" # pyg_batch.batch (node-to-graph assignment)") + print(" output = model(pyg_batch.features, pyg_batch.edge_index)") + print() + + dataset.close() + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + """Run all tutorial sections.""" + print() + print("╔" + "═" * 68 + "╗") + print("║" + " Tutorial 3: Custom Collation for GNNs ".center(68) + "║") + print("╚" + "═" * 68 + "╝") + print() + + # Section 1: KNN transform + section_1_knn_transform() + + # Section 2: PyG collator + section_2_pyg_collator() + + # Section 3: Complete pipeline + section_3_complete_pipeline() + + print("=" * 70) + print("Tutorial 3 Complete!") + print() + print("Key takeaways:") + print(" 1. KNNNeighbors transform: Computes graph structure from point clouds") + print(" 2. PyGCollator: Uses Batch.from_data_list() for simple, correct batching") + print(" 3. Returns a PyG Batch object that works directly with PyG models") + print(" 4. The batch tensor tracks which nodes belong to which graph") + print() + print("Next: Tutorial 4 - Configuration with Hydra") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/minimal/datapipes/tutorial_04_hydra_config.py b/examples/minimal/datapipes/tutorial_04_hydra_config.py new file mode 100644 index 0000000000..060a4e3e41 --- /dev/null +++ b/examples/minimal/datapipes/tutorial_04_hydra_config.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 4: Hydra Configuration for DataPipes +============================================== + +This tutorial demonstrates how to configure PhysicsNeMo DataPipes entirely +through Hydra YAML files with minimal Python code. + +The key insight: Hydra's `instantiate()` can build the entire datapipe +(reader, transforms, dataset, dataloader) from configuration alone using +recursive instantiation. + +Prerequisites +------------- +Generate synthetic data before running: + + # For point cloud data: + python generate_variable_points_data.py -n 100 -s "coords:3 features:8" --min-points 50000 --max-points 100000 -b zarr -o output/pointcloud_data/ + +Run this tutorial: + + # Use point cloud configuration (default) + python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud + + # Override values from command line + python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \\ + dataloader.batch_size=8 + + # Override transform parameters + python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \\ + subsample.n_points=5000 + +Configuration Files +------------------- +- conf/tutorial_04_pointcloud.yaml - Point cloud pipeline with subsampling +""" + +import hydra +from omegaconf import DictConfig, OmegaConf + +from physicsnemo.datapipes.core import DataLoader + + +@hydra.main( + version_base=None, + config_path="./conf", + config_name="tutorial_04_pointcloud", +) +def main(cfg: DictConfig): + """ + Main entry point - demonstrates Hydra-based datapipe configuration. + + The entire pipeline is built from the YAML configuration with a single + instantiate call that recursively builds DataLoader -> Dataset -> Reader + Transforms. 
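+    A minimal sketch of the idea (hypothetical structure; the actual
+    conf/tutorial_04_pointcloud.yaml may organize its keys differently):
+
+    >>> cfg = {
+    ...     "_target_": "physicsnemo.datapipes.core.DataLoader",
+    ...     "batch_size": 4,
+    ...     "dataset": {
+    ...         "_target_": "physicsnemo.datapipes.core.Dataset",
+    ...         "reader": {
+    ...             "_target_": "physicsnemo.datapipes.core.readers.ZarrReader",
+    ...             "path": "./output/pointcloud_data/",
+    ...             "group_pattern": "*.zarr",
+    ...         },
+    ...     },
+    ... }
+    >>> dataloader = hydra.utils.instantiate(cfg)  # builds the whole pipeline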
+ """ + print() + print("=" * 70) + print("Tutorial 4: Hydra Configuration for DataPipes") + print("=" * 70) + print() + + # Show the resolved configuration + print("Resolved Configuration:") + print("-" * 70) + print(OmegaConf.to_yaml(cfg)) + print("-" * 70) + print() + + # Build entire datapipe from config with a single instantiate call + # Hydra recursively instantiates: DataLoader -> Dataset -> Reader + Transforms + print("Building datapipe from configuration (single instantiate call)...") + dataloader: DataLoader = hydra.utils.instantiate(cfg.dataloader) + dataset = dataloader.dataset + + print(f" Reader: {dataset.reader}") + # Handle different transform configurations + if dataset.transforms is None: + transform_names = [] + elif hasattr(dataset.transforms, "transforms"): + # Compose wraps multiple transforms + transform_names = [type(t).__name__ for t in dataset.transforms.transforms] + else: + # Single transform + transform_names = [type(dataset.transforms).__name__] + print(f" Transforms: {transform_names}") + print(f" Dataset: {len(dataset)} samples") + print( + f" DataLoader: {len(dataloader)} batches (batch_size={cfg.dataloader.batch_size})" + ) + print() + + # Run training loop + print("Training Loop:") + print("-" * 70) + + num_epochs = cfg.training.get("num_epochs", 2) + log_interval = cfg.training.get("log_interval", 1) + + for epoch in range(num_epochs): + for batch_idx, batch_data in enumerate(dataloader): + if batch_idx % log_interval == 0: + print(f"Epoch {epoch}, Batch {batch_idx}:") + for key in batch_data.keys(): + shape = tuple(batch_data[key].shape) + print(f" {key}: {shape}") + + print(f"Epoch {epoch} complete: {len(dataloader)} batches") + + print("-" * 70) + print() + + # Cleanup + dataset.close() + + # Print summary + print("=" * 70) + print("Tutorial 4 Complete!") + print() + print("Key takeaways:") + print(" 1. Define datapipes entirely in YAML configuration") + print(" 2. Use a single hydra.utils.instantiate() call to build everything") + print( + " 3. Hydra recursively instantiates: DataLoader -> Dataset -> Reader + Transforms" + ) + print(" 4. Override any parameter from command line:") + print(" python tutorial_04_hydra_config.py dataloader.batch_size=8") + print(" 5. Override transform parameters (using top-level keys from defaults):") + print(" python tutorial_04_hydra_config.py subsample.n_points=5000") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/datapipes/core/__init__.py b/physicsnemo/datapipes/core/__init__.py new file mode 100644 index 0000000000..39792e52b5 --- /dev/null +++ b/physicsnemo/datapipes/core/__init__.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +datapipe - High-performance GPU-centric data loading for Scientific ML + +A modular, composable data pipeline for physics and scientific machine learning. 
+Designed for clean separation of concerns: + +- **Readers**: Load data from sources → TensorDict tuples with CPU tensors +- **Transforms**: Process TensorDict data +- **Dataset**: Reader + transforms pipeline with optional auto device transfer +- **DataLoader**: Batched iteration with optional prefetching + +Example: + >>> import physicsnemo.datapipes.core as dp + >>> from tensordict import TensorDict + >>> + >>> # Create a dataset with transforms and automatic device transfer + >>> dataset = dp.Dataset( + ... reader=dp.HDF5Reader("data.h5", fields=["pressure", "velocity"]), + ... transforms=[ + ... dp.Normalize(input_keys=["pressure"], means={"pressure": 0.0}, stds={"pressure": 1.0}), + ... dp.SubsamplePoints(input_keys=["pressure", "velocity"], n=10000), + ... ], + ... device="cuda", # Automatic GPU transfer! + ... ) + >>> + >>> # Create a dataloader + >>> loader = dp.DataLoader(dataset, batch_size=16, shuffle=True) + >>> + >>> # Iterate over batches + >>> for data, metadata in loader: + ... output = model(data["pressure"]) +""" + +from tensordict import TensorDict + +from physicsnemo.datapipes.core.collate import ( + Collator, + ConcatCollator, + DefaultCollator, + FunctionCollator, + concat_collate, + default_collate, + get_collator, +) +from physicsnemo.datapipes.core.dataloader import DataLoader +from physicsnemo.datapipes.core.dataset import Dataset +from physicsnemo.datapipes.core.readers import ( + HDF5Reader, + NumpyReader, + Reader, + ZarrReader, +) +from physicsnemo.datapipes.core.registry import ( + READER_REGISTRY, + TRANSFORM_REGISTRY, + ComponentRegistry, + register_reader, + register_transform, +) +from physicsnemo.datapipes.core.transforms import ( + Compose, + Normalize, + SubsamplePoints, + Transform, +) + +__version__ = "0.1.0" + +__all__ = [ + # Core + "TensorDict", # Re-export from tensordict + "Dataset", + "DataLoader", + # Transforms + "Transform", + "Compose", + "Normalize", + "SubsamplePoints", + # Readers + "Reader", + "HDF5Reader", + "ZarrReader", + "NumpyReader", + # Collation + "Collator", + "DefaultCollator", + "ConcatCollator", + "FunctionCollator", + "default_collate", + "concat_collate", + "get_collator", + # Registry + "ComponentRegistry", + "TRANSFORM_REGISTRY", + "READER_REGISTRY", + "register_transform", + "register_reader", +] diff --git a/physicsnemo/datapipes/core/collate.py b/physicsnemo/datapipes/core/collate.py new file mode 100644 index 0000000000..69628929ba --- /dev/null +++ b/physicsnemo/datapipes/core/collate.py @@ -0,0 +1,414 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Collation utilities - Batch multiple (TensorDict, metadata) tuples. + +Collators combine multiple (TensorDict, dict) tuples from Dataset into a single +batched output suitable for model consumption. By default, returns just the +batched TensorDict for PyTorch DataLoader compatibility. 
When collate_metadata=True, +returns a tuple of (TensorDict, list[dict]). + +The default collator stacks TensorDicts along batch dimension using TensorDict.stack(). +Metadata collation is optional and disabled by default. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Sequence, Union + +import torch +from tensordict import TensorDict + + +def _collate_metadata(metadata_list: Sequence[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Collate metadata from multiple samples. + + Simply returns the list of metadata dicts as-is. Each metadata dict + corresponds to one sample in the batch. + + Args: + metadata_list: Sequence of metadata dicts. + + Returns: + List of metadata dicts. + """ + return list(metadata_list) + + +class Collator(ABC): + """ + Abstract base class for collators. + + Collators take a sequence of (TensorDict, dict) tuples and combine them + into a batched output. By default, returns just the batched TensorDict + for PyTorch DataLoader compatibility. When collate_metadata=True, returns + a tuple of (TensorDict, list[dict]). + + Example: + >>> class MyCollator(Collator): + ... def __call__( + ... self, + ... samples: Sequence[tuple[TensorDict, dict]] + ... ) -> TensorDict: + ... # Custom batching logic + ... ... + """ + + @abstractmethod + def __call__( + self, samples: Sequence[tuple[TensorDict, dict[str, Any]]] + ) -> Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]: + """ + Collate a batch of samples. + + Args: + samples: Sequence of (TensorDict, metadata dict) tuples to batch. + + Returns: + Batched TensorDict, or tuple of (batched TensorDict, list of metadata dicts) + if collate_metadata=True. + """ + raise NotImplementedError + + +class DefaultCollator(Collator): + """ + Default collator that stacks TensorDicts along a new batch dimension. + + Uses TensorDict.stack() to efficiently batch all tensors, creating + shape [batch_size, ...original_shape] for each field. + + All samples must have: + - The same tensor keys + - Tensors with matching shapes (per key) + - Tensors on the same device + + By default, returns just the batched TensorDict for PyTorch DataLoader + compatibility. Set collate_metadata=True to also return metadata. + + Example: + >>> data1 = TensorDict({"x": torch.randn(10, 3)}, device="cpu") + >>> data2 = TensorDict({"x": torch.randn(10, 3)}, device="cpu") + >>> samples = [ + ... (data1, {"file": "a.h5"}), + ... (data2, {"file": "b.h5"}), + ... ] + >>> collator = DefaultCollator() + >>> batched_data = collator(samples) + >>> batched_data["x"].shape # torch.Size([2, 10, 3]) + >>> + >>> # With metadata collation enabled: + >>> collator = DefaultCollator(collate_metadata=True) + >>> batched_data, metadata_list = collator(samples) + >>> metadata_list # [{"file": "a.h5"}, {"file": "b.h5"}] + """ + + def __init__( + self, + *, + stack_dim: int = 0, + keys: Optional[list[str]] = None, + collate_metadata: bool = False, + ) -> None: + """ + Initialize the collator. + + Args: + stack_dim: Dimension along which to stack tensors (default: 0). + keys: If provided, only collate these tensor keys. Others are ignored. + collate_metadata: If True, collate metadata into list (default: False). + Default is False for compatibility with PyTorch DataLoader. 
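+        Example (illustrative; with ``keys`` set, only the selected fields are
+        stacked and any other tensor keys are dropped from the batch):
+
+        >>> collator = DefaultCollator(keys=["x"], stack_dim=0)
+        >>> batched = collator(samples)  # batched contains only "x"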
+ """ + self.stack_dim = stack_dim + self.keys = keys + self.collate_metadata = collate_metadata + + def __call__( + self, samples: Sequence[tuple[TensorDict, dict[str, Any]]] + ) -> Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]: + """ + Collate samples by stacking TensorDicts. + + Args: + samples: Sequence of (TensorDict, metadata) tuples to batch. + + Returns: + Batched TensorDict if collate_metadata=False (default), + or tuple of (batched TensorDict, list of metadata dicts) if collate_metadata=True. + + Raises: + ValueError: If samples is empty or samples have mismatched keys/shapes. + """ + if not samples: + raise ValueError("Cannot collate empty sequence of samples") + + # Separate data and metadata + data_list = [data for data, _ in samples] + + # Use TensorDict.stack() for efficient batching + if self.keys is not None: + # Filter to only requested keys + data_list = [data.select(*self.keys) for data in data_list] + + batched_data = torch.stack(data_list, dim=self.stack_dim) + + # Collate metadata only if requested + if self.collate_metadata: + metadata_list = [meta for _, meta in samples] + return batched_data, _collate_metadata(metadata_list) + + return batched_data + + +class ConcatCollator(Collator): + """ + Collator that concatenates tensors along an existing dimension. + + Unlike DefaultCollator which creates a new batch dimension, this + concatenates along an existing dimension. Useful for point clouds + or other variable-length data where you want to combine all points. + + Optionally adds batch indices to track which points came from which sample. + By default, returns just the batched TensorDict for PyTorch DataLoader + compatibility. Set collate_metadata=True to also return metadata. + + Example: + >>> data1 = TensorDict({"points": torch.randn(100, 3)}) + >>> data2 = TensorDict({"points": torch.randn(150, 3)}) + >>> samples = [ + ... (data1, {"file": "a.h5"}), + ... (data2, {"file": "b.h5"}), + ... ] + >>> collator = ConcatCollator(dim=0, add_batch_idx=True) + >>> batched_data = collator(samples) + >>> batched_data["points"].shape # torch.Size([250, 3]) + >>> batched_data["batch_idx"].shape # torch.Size([250]) + >>> + >>> # With metadata collation enabled: + >>> collator = ConcatCollator(dim=0, add_batch_idx=True, collate_metadata=True) + >>> batched_data, metadata_list = collator(samples) + >>> metadata_list # [{"file": "a.h5"}, {"file": "b.h5"}] + """ + + def __init__( + self, + *, + dim: int = 0, + add_batch_idx: bool = True, + batch_idx_key: str = "batch_idx", + keys: Optional[list[str]] = None, + collate_metadata: bool = False, + ) -> None: + """ + Initialize the collator. + + Args: + dim: Dimension along which to concatenate. + add_batch_idx: If True, add a tensor of batch indices. + batch_idx_key: Key for the batch index tensor. + keys: If provided, only collate these tensor keys. + collate_metadata: If True, collate metadata into lists (default: False). + Default is False for compatibility with PyTorch DataLoader. + """ + self.dim = dim + self.add_batch_idx = add_batch_idx + self.batch_idx_key = batch_idx_key + self.keys = keys + self.collate_metadata = collate_metadata + + def __call__( + self, samples: Sequence[tuple[TensorDict, dict[str, Any]]] + ) -> Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]: + """ + Collate samples by concatenating tensors. + + Args: + samples: Sequence of (TensorDict, metadata) tuples to batch. 
+ + Returns: + Batched TensorDict if collate_metadata=False (default), + or tuple of (batched TensorDict, list of metadata dicts) if collate_metadata=True. + + Raises: + ValueError: If samples is empty. + """ + if not samples: + raise ValueError("Cannot collate empty sequence of samples") + + # Separate data + data_list = [data for data, _ in samples] + + first_data = data_list[0] + keys = self.keys if self.keys else list(first_data.keys()) + device = first_data.device + + batched_tensors = {} + sizes = [] # Track sizes for batch indices + + for key in keys: + tensors = [] + for data in data_list: + if key not in data.keys(): + raise ValueError(f"Data missing key '{key}'") + tensor = data[key] + tensors.append(tensor) + if key == keys[0]: # Track sizes from first key + sizes.append(tensor.shape[self.dim]) + + batched_tensors[key] = torch.cat(tensors, dim=self.dim) + + # Add batch indices + if self.add_batch_idx: + batch_indices = [] + for i, size in enumerate(sizes): + batch_indices.append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + batched_tensors[self.batch_idx_key] = torch.cat(batch_indices, dim=0) + + # Create batched TensorDict + batched_data = TensorDict(batched_tensors, device=device) + + # Collate metadata only if requested + if self.collate_metadata: + metadata_list = [meta for _, meta in samples] + return batched_data, _collate_metadata(metadata_list) + + return batched_data + + +class FunctionCollator(Collator): + """ + Collator that wraps a user-provided function. + + Allows using any function as a collator without subclassing. + + Example: + >>> def my_collate(samples): + ... # Custom logic + ... data_list = [d for d, _ in samples] + ... metadata_list = [m for _, m in samples] + ... return TensorDict.stack(data_list), metadata_list + >>> collator = FunctionCollator(my_collate) + """ + + def __init__( + self, + fn: Callable[ + [Sequence[tuple[TensorDict, dict[str, Any]]]], + tuple[TensorDict, list[dict[str, Any]]], + ], + ) -> None: + """ + Initialize with a collation function. + + Args: + fn: Function that takes a sequence of (TensorDict, dict) tuples + and returns a (TensorDict, list[dict]) tuple. + """ + self.fn = fn + + def __call__( + self, samples: Sequence[tuple[TensorDict, dict[str, Any]]] + ) -> Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]: + """Apply the wrapped function.""" + return self.fn(samples) + + +# Default collator instance +_default_collator = DefaultCollator() + + +def default_collate( + samples: Sequence[tuple[TensorDict, dict[str, Any]]], +) -> tuple[TensorDict, list[dict[str, Any]]]: + """ + Default collation function using stacking. + + Convenience function that uses DefaultCollator. + Metadata is collated into a list of dicts. + + Args: + samples: Sequence of (TensorDict, metadata) tuples to batch. + + Returns: + Tuple of (batched TensorDict, list of metadata dicts). + """ + return _default_collator(samples) + + +def concat_collate( + samples: Sequence[tuple[TensorDict, dict[str, Any]]], + dim: int = 0, + add_batch_idx: bool = True, +) -> tuple[TensorDict, list[dict[str, Any]]]: + """ + Collation function using concatenation. + + Convenience function that uses ConcatCollator. + Metadata is collated into a list of dicts. + + Args: + samples: Sequence of (TensorDict, metadata) tuples to batch. + dim: Dimension along which to concatenate. + add_batch_idx: If True, add batch index tensor. + + Returns: + Tuple of (batched TensorDict, list of metadata dicts). 
+ """ + collator = ConcatCollator(dim=dim, add_batch_idx=add_batch_idx) + return collator(samples) + + +def get_collator( + collate_fn: Optional[ + Union[ + Collator, + Callable[ + [Sequence[tuple[TensorDict, dict[str, Any]]]], + tuple[TensorDict, list[dict[str, Any]]], + ], + ] + ] = None, + *, + collate_metadata: bool = False, +) -> Collator: + """ + Get a Collator instance from various input types. + + Args: + collate_fn: Collator, callable, or None (uses default). + collate_metadata: If True, collate metadata into list (default: False). + Only used when collate_fn is None. + Default is False for compatibility with PyTorch DataLoader. + + Returns: + Collator instance. + """ + if collate_fn is None: + return DefaultCollator(collate_metadata=collate_metadata) + elif isinstance(collate_fn, Collator): + return collate_fn + elif callable(collate_fn): + return FunctionCollator(collate_fn) + else: + raise TypeError( + f"collate_fn must be Collator, callable, or None, " + f"got {type(collate_fn).__name__}" + ) diff --git a/physicsnemo/datapipes/core/dataloader.py b/physicsnemo/datapipes/core/dataloader.py new file mode 100644 index 0000000000..6f26e1bcbf --- /dev/null +++ b/physicsnemo/datapipes/core/dataloader.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DataLoader - Batched iteration over datasets with prefetching. + +The DataLoader orchestrates efficient batch loading by leveraging +the Dataset's prefetching capabilities with CUDA streams. +By default, returns batched TensorDict for PyTorch DataLoader compatibility. +When collate_metadata=True, returns (TensorDict, list[dict]) tuples. +""" + +from __future__ import annotations + +from typing import Any, Callable, Iterator, Optional, Sequence, Union + +import torch +from tensordict import TensorDict +from torch.utils.data import RandomSampler, Sampler, SequentialSampler + +from physicsnemo.datapipes.core.collate import Collator, get_collator +from physicsnemo.datapipes.core.dataset import Dataset + + +class DataLoader: + """ + Batched iteration over a Dataset with stream-based prefetching. + + Unlike PyTorch's DataLoader which uses CPU multiprocessing, this + DataLoader uses CUDA streams to overlap data loading, preprocessing, + and collation. This is more efficient for SciML workloads where: + - Datasets are huge + - Batches are small + - Preprocessing benefits from GPU acceleration + + Features: + - Stream-based parallelism (one stream per sample in flight) + - Toggleable prefetching for debugging + - Compatible with PyTorch samplers (DistributedSampler, etc.) + - Familiar torch DataLoader interface + + Example: + >>> from physicsnemo.datapipes import DataLoader, Dataset, HDF5Reader, ToDevice, Compose + >>> + >>> dataset = Dataset( + ... HDF5Reader("data.h5"), + ... transforms=Compose([ToDevice("cuda"), ...]) + ... 
) + >>> loader = DataLoader(dataset, batch_size=16, shuffle=True) + >>> + >>> for batch in loader: + ... output = model(batch["input"]) + + With DistributedSampler: + >>> from torch.utils.data.distributed import DistributedSampler + >>> sampler = DistributedSampler(dataset) + >>> loader = DataLoader(dataset, batch_size=16, sampler=sampler) + """ + + def __init__( + self, + dataset: Dataset, + *, + batch_size: int = 1, + shuffle: bool = False, + sampler: Optional[Sampler] = None, + drop_last: bool = False, + collate_fn: Optional[ + Union[ + Collator, + Callable[ + [Sequence[tuple[TensorDict, dict[str, Any]]]], + tuple[TensorDict, list[dict[str, Any]]], + ], + ] + ] = None, + collate_metadata: bool = False, + prefetch_factor: int = 2, + num_streams: int = 4, + use_streams: bool = True, + ) -> None: + """ + Initialize the DataLoader. + + Args: + dataset: Dataset to load from. + batch_size: Number of samples per batch (default: 1). + shuffle: If True, shuffle indices each epoch. Ignored if sampler provided. + sampler: Custom sampler for index generation. If provided, shuffle is ignored. + drop_last: If True, drop the last incomplete batch. + collate_fn: Function to collate samples into batches. Defaults to stacking. + collate_metadata: If True, collate metadata into a list of dicts (default: False). + Set to False for compatibility with PyTorch DataLoader. + Only used when collate_fn is None (uses default collator). + prefetch_factor: Number of batches to prefetch ahead (default: 2). + Set to 0 to disable prefetching. + num_streams: Number of CUDA streams for prefetching (default: 4). + use_streams: If True, use CUDA streams for overlap (default: True). + Set False for debugging or CPU-only operation. + + Raises: + ValueError: If batch_size < 1. + """ + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}") + + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.prefetch_factor = prefetch_factor + self.num_streams = num_streams + self.use_streams = use_streams and torch.cuda.is_available() + + # Handle sampler + if sampler is not None: + self.sampler = sampler + elif shuffle: + self.sampler = RandomSampler(dataset) + else: + self.sampler = SequentialSampler(dataset) + + # Handle collation + self.collate_fn = get_collator(collate_fn, collate_metadata=collate_metadata) + + # Create CUDA streams for prefetching + self._streams: list[torch.cuda.Stream] = [] + if self.use_streams: + for _ in range(num_streams): + self._streams.append(torch.cuda.Stream()) + + def __len__(self) -> int: + """Return the number of batches.""" + n_samples = len(self.dataset) + if self.drop_last: + return n_samples // self.batch_size + return (n_samples + self.batch_size - 1) // self.batch_size + + def _generate_batches(self) -> Iterator[list[int]]: + """Generate batches of indices.""" + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + if batch and not self.drop_last: + yield batch + + def __iter__( + self, + ) -> Iterator[Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]]: + """ + Iterate over batches. + + Uses stream-based prefetching when enabled to overlap IO, + GPU transfers, and computation. + + Yields: + Batched TensorDict if collate_metadata=False (default), + or tuple of (batched TensorDict, list of metadata dicts) if collate_metadata=True. 
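+
+        Example:
+            A minimal sketch, assuming ``dataset`` is an existing Dataset and
+            metadata collation is enabled so each iteration yields a pair:
+
+            >>> loader = DataLoader(dataset, batch_size=8, collate_metadata=True)
+            >>> for batch, metadata in loader:
+            ...     source_files = [m.get("source_file") for m in metadata]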
+ """ + if self.prefetch_factor > 0 and self.use_streams: + yield from self._iter_prefetch() + else: + yield from self._iter_simple() + + def _iter_simple( + self, + ) -> Iterator[Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]]: + """Simple synchronous iteration without prefetching.""" + for batch_indices in self._generate_batches(): + samples = [self.dataset[idx] for idx in batch_indices] + yield self.collate_fn(samples) + + def _iter_prefetch( + self, + ) -> Iterator[Union[TensorDict, tuple[TensorDict, list[dict[str, Any]]]]]: + """ + Iteration with stream-based prefetching. + + Strategy: + 1. Prefetch `prefetch_factor` batches worth of samples + 2. As we yield batches, prefetch more to keep the pipeline full + 3. Each sample in a batch uses a different stream for overlap + """ + # Collect all batches upfront for prefetch planning + all_batches = list(self._generate_batches()) + if not all_batches: + return + + num_prefetch_batches = min(self.prefetch_factor, len(all_batches)) + stream_idx = 0 + + # Start initial prefetch + prefetched_up_to = 0 + for batch_idx in range(num_prefetch_batches): + for sample_idx in all_batches[batch_idx]: + stream = self._streams[stream_idx % self.num_streams] + self.dataset.prefetch(sample_idx, stream=stream) + stream_idx += 1 + prefetched_up_to = batch_idx + 1 + + # Yield batches and prefetch more + for batch_idx, batch_indices in enumerate(all_batches): + # Collect samples (uses prefetched if available) + samples = [self.dataset[idx] for idx in batch_indices] + batch = self.collate_fn(samples) + + # Prefetch next batch if available + next_prefetch_idx = prefetched_up_to + if next_prefetch_idx < len(all_batches): + for sample_idx in all_batches[next_prefetch_idx]: + stream = self._streams[stream_idx % self.num_streams] + self.dataset.prefetch(sample_idx, stream=stream) + stream_idx += 1 + prefetched_up_to += 1 + + yield batch + + # Clean up any remaining prefetch state + self.dataset.cancel_prefetch() + + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for the sampler. + + Required for DistributedSampler to shuffle properly across epochs. + + Args: + epoch: Current epoch number. + """ + if hasattr(self.sampler, "set_epoch"): + self.sampler.set_epoch(epoch) + + def enable_prefetch(self) -> None: + """Enable stream-based prefetching.""" + if not torch.cuda.is_available(): + raise RuntimeError( + "CUDA is not available, cannot enable stream prefetching" + ) + + if not self._streams: + for _ in range(self.num_streams): + self._streams.append(torch.cuda.Stream()) + + self.use_streams = True + + def disable_prefetch(self) -> None: + """Disable prefetching (useful for debugging).""" + self.use_streams = False + self.dataset.cancel_prefetch() + + def __repr__(self) -> str: + return ( + f"DataLoader(\n" + f" dataset={self.dataset},\n" + f" batch_size={self.batch_size},\n" + f" shuffle={self.shuffle},\n" + f" drop_last={self.drop_last},\n" + f" prefetch_factor={self.prefetch_factor},\n" + f" num_streams={self.num_streams},\n" + f" use_streams={self.use_streams}\n" + f")" + ) diff --git a/physicsnemo/datapipes/core/dataset.py b/physicsnemo/datapipes/core/dataset.py new file mode 100644 index 0000000000..7d55df25f8 --- /dev/null +++ b/physicsnemo/datapipes/core/dataset.py @@ -0,0 +1,363 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Dataset - Combines a Reader with a transform pipeline. + +The Dataset is the primary interface for accessing preprocessed data. +It wraps a Reader and applies transforms to produce ready-to-use TensorDicts. +Supports prefetching with CUDA streams for overlapped IO and computation, +and automatic device transfer when device parameter is specified. +""" + +from __future__ import annotations + +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Iterator, Optional, Sequence, Union + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.readers.base import Reader +from physicsnemo.datapipes.core.transforms.base import Transform +from physicsnemo.datapipes.core.transforms.compose import Compose + + +@dataclass +class _PrefetchResult: + """Result of a prefetch operation.""" + + index: int + data: Optional[TensorDict] = None + metadata: Optional[dict[str, Any]] = None + error: Optional[Exception] = None + event: Optional[torch.cuda.Event] = None # For stream sync + + +class Dataset: + """ + A dataset combining a Reader with a transform pipeline. + + The Dataset provides a torch-like interface for accessing data: + - Indexing: dataset[i] returns transformed sample i + - Iteration: for sample in dataset + - Length: len(dataset) + - Prefetching: dataset.prefetch(i, stream) for async loading + + The pipeline is: Reader → Transforms → Sample + + Prefetching Model: + The dataset supports prefetching samples using a thread pool. + When a CUDA stream is provided, GPU operations (device transfer, + GPU transforms) happen on that stream, allowing overlap with + other computation. + + >>> # Start prefetching + >>> dataset.prefetch(0, stream=stream0) + >>> dataset.prefetch(1, stream=stream1) + >>> + >>> # Retrieve results (waits if not ready) + >>> sample_0 = dataset[0] # Uses prefetched result + + Example: + >>> from physicsnemo.datapipes import Dataset, HDF5Reader, Normalize, ToDevice, Compose + >>> + >>> reader = HDF5Reader("data.h5", fields=["pressure", "velocity"]) + >>> transforms = Compose([ + ... ToDevice("cuda"), + ... Normalize(["pressure"], means={"pressure": 0.0}, stds={"pressure": 1.0}), + ... ]) + >>> + >>> dataset = Dataset(reader, transforms=transforms) + >>> sample = dataset[0] + """ + + def __init__( + self, + reader: Reader, + *, + transforms: Optional[Union[Transform, Sequence[Transform]]] = None, + device: Optional[Union[str, torch.device]] = None, + num_workers: int = 2, + ) -> None: + """ + Initialize the dataset. + + Args: + reader: Data reader providing raw samples. + transforms: Transform or sequence of transforms to apply. + If a sequence, they are composed in order. + device: Target device for automatic transfer (e.g., "cuda", "cuda:0"). + If None, no automatic transfer is performed (data stays on CPU). + When specified, data is transferred to this device before transforms. 
+ num_workers: Number of worker threads for prefetching (default: 2). + + Raises: + TypeError: If reader is not a Reader instance. + """ + if not isinstance(reader, Reader): + raise TypeError( + f"reader must be a Reader instance, got {type(reader).__name__}" + ) + + self.reader = reader + self.num_workers = num_workers + self.target_device = torch.device(device) if device is not None else None + + # Handle transforms + if transforms is None: + self.transforms: Optional[Transform] = None + elif isinstance(transforms, Transform): + self.transforms = transforms + elif isinstance(transforms, Sequence): + if len(transforms) == 0: + self.transforms = None + elif len(transforms) == 1: + self.transforms = transforms[0] + else: + self.transforms = Compose(transforms) + else: + raise TypeError( + f"transforms must be Transform, Sequence[Transform], or None, " + f"got {type(transforms).__name__}" + ) + + # Share device with transforms so their internal state is on the right device + if self.target_device is not None and self.transforms is not None: + self.transforms.to(self.target_device) + + # Prefetch state - using thread-safe dict for results + # Key: index, Value: Future[_PrefetchResult] + self._prefetch_futures: dict[int, Future[_PrefetchResult]] = {} + self._executor: Optional[ThreadPoolExecutor] = None + + def _ensure_executor(self) -> ThreadPoolExecutor: + """Lazily create the thread pool executor.""" + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self.num_workers, + thread_name_prefix="datapipe_prefetch", + ) + return self._executor + + def _load_and_transform( + self, + index: int, + stream: Optional[torch.cuda.Stream] = None, + ) -> _PrefetchResult: + """ + Load a sample and apply transforms. Called by worker threads. + + Args: + index: Sample index. + stream: Optional CUDA stream for GPU operations. + + Returns: + PrefetchResult with data, metadata, or error. + """ + result = _PrefetchResult(index=index) + + try: + # Load from reader (CPU, potentially slow IO) + data, metadata = self.reader[index] + + # Auto-transfer to target device if specified + if self.target_device is not None: + if stream is not None: + with torch.cuda.stream(stream): + data = data.to(self.target_device, non_blocking=True) + else: + data = data.to(self.target_device, non_blocking=True) + + # Apply transforms (data is now on target device if specified) + if self.transforms is not None: + if stream is not None: + with torch.cuda.stream(stream): + data = self.transforms(data) + # Record event for synchronization + result.event = torch.cuda.Event() + result.event.record(stream) + else: + data = self.transforms(data) + + result.data = data + result.metadata = metadata + + except Exception as e: + result.error = e + + return result + + def prefetch( + self, + index: int, + stream: Optional[torch.cuda.Stream] = None, + ) -> None: + """ + Start prefetching a sample asynchronously. + + The sample will be loaded in a background thread. If a CUDA stream + is provided, GPU operations happen on that stream. + + Call __getitem__ to retrieve the result (it will wait if needed). + + Args: + index: Sample index to prefetch. + stream: Optional CUDA stream for GPU operations. 
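+
+        Example:
+            A minimal sketch, assuming CUDA is available and ``dataset`` is an
+            existing Dataset instance:
+
+            >>> stream = torch.cuda.Stream()
+            >>> dataset.prefetch(0, stream=stream)
+            >>> data, metadata = dataset[0]  # Waits for the prefetch if needed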
+ """ + # Don't prefetch if already in flight + if index in self._prefetch_futures: + return + + executor = self._ensure_executor() + future = executor.submit(self._load_and_transform, index, stream) + self._prefetch_futures[index] = future + + def prefetch_batch( + self, + indices: Sequence[int], + streams: Optional[Sequence[torch.cuda.Stream]] = None, + ) -> None: + """ + Start prefetching multiple samples. + + Args: + indices: Sample indices to prefetch. + streams: Optional CUDA streams, one per index. If shorter than + indices, streams are cycled. If None, no streams used. + """ + for i, idx in enumerate(indices): + stream = None + if streams: + stream = streams[i % len(streams)] + self.prefetch(idx, stream=stream) + + def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: + """ + Get a transformed sample by index. + + If the index was prefetched, returns the prefetched result + (waiting for completion if necessary). Otherwise loads synchronously. + + Args: + index: Sample index. + + Returns: + Tuple of (TensorDict with transformed data, metadata dict). + + Raises: + IndexError: If index is out of range. + Exception: If prefetch failed, re-raises the error. + """ + # Check if prefetched + future = self._prefetch_futures.pop(index, None) + + if future is not None: + # Wait for prefetch to complete + result = future.result() + + if result.error is not None: + raise result.error + + # Sync stream if needed + if result.event is not None: + result.event.synchronize() + + return result.data, result.metadata + + # Not prefetched, load synchronously + data, metadata = self.reader[index] + + # Auto-transfer to target device if specified + if self.target_device is not None: + data = data.to(self.target_device, non_blocking=True) + + # Apply transforms + if self.transforms is not None: + data = self.transforms(data) + + return data, metadata + + def cancel_prefetch(self, index: Optional[int] = None) -> None: + """ + Cancel prefetch requests. + + Note: Already-running tasks will complete, but results are discarded. + + Args: + index: Specific index to cancel. If None, cancels all. + """ + if index is None: + # Cancel all - just clear the dict, let futures complete + self._prefetch_futures.clear() + else: + self._prefetch_futures.pop(index, None) + + def __len__(self) -> int: + """Return the number of samples in the dataset.""" + return len(self.reader) + + def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: + """ + Iterate over all samples. + + Note: This does NOT automatically prefetch. For prefetched iteration, + use the DataLoader which manages prefetching strategy. + """ + for i in range(len(self)): + yield self[i] + + @property + def field_names(self) -> list[str]: + """List of field names in samples (from reader).""" + return self.reader.field_names + + @property + def prefetch_count(self) -> int: + """Number of items currently being prefetched.""" + return len(self._prefetch_futures) + + def close(self) -> None: + """Close the dataset and stop prefetching.""" + # Wait for any in-flight prefetch tasks to complete before shutdown. + # This prevents "cannot schedule new futures after shutdown" errors + # from libraries like zarr that use async I/O internally. 
+        for future in self._prefetch_futures.values():
+            try:
+                future.result(timeout=30.0)  # Wait up to 30s per task
+            except Exception:  # noqa: BLE001, S110
+                pass  # Ignore errors during shutdown
+
+        self._prefetch_futures.clear()
+
+        if self._executor is not None:
+            self._executor.shutdown(wait=True)
+            self._executor = None
+
+        self.reader.close()
+
+    def __enter__(self) -> "Dataset":
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Context manager exit."""
+        self.close()
+
+    def __repr__(self) -> str:
+        transform_str = repr(self.transforms) if self.transforms else "None"
+        return f"Dataset(\n  reader={self.reader},\n  transforms={transform_str}\n)"
diff --git a/physicsnemo/datapipes/core/readers/__init__.py b/physicsnemo/datapipes/core/readers/__init__.py
new file mode 100644
index 0000000000..dbcb119395
--- /dev/null
+++ b/physicsnemo/datapipes/core/readers/__init__.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Readers module - Data source interfaces for loading raw data.
+
+Readers are responsible for:
+- Loading data from various sources (HDF5, Zarr, NumPy, etc.)
+- Converting raw data to CPU torch tensors
+- Returning (TensorDict, metadata) pairs ready for the transform pipeline
+- Leaving device transfers, threading, and prefetching to Dataset and DataLoader
+"""
+
+from physicsnemo.datapipes.core.readers.base import Reader
+from physicsnemo.datapipes.core.readers.hdf5 import HDF5Reader
+from physicsnemo.datapipes.core.readers.numpy import NumpyReader
+from physicsnemo.datapipes.core.readers.tensorstore_zarr import TensorStoreZarrReader
+from physicsnemo.datapipes.core.readers.vtk import VTKReader
+from physicsnemo.datapipes.core.readers.zarr import ZarrReader
+
+__all__ = [
+    "Reader",
+    "HDF5Reader",
+    "ZarrReader",
+    "NumpyReader",
+    "VTKReader",
+    "TensorStoreZarrReader",
+]
diff --git a/physicsnemo/datapipes/core/readers/base.py b/physicsnemo/datapipes/core/readers/base.py
new file mode 100644
index 0000000000..bbddf06a76
--- /dev/null
+++ b/physicsnemo/datapipes/core/readers/base.py
@@ -0,0 +1,226 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Reader base class - Abstract interface for data sources.
+ +Readers are simple, transactional data loaders. They load data from sources +and return TensorDict instances with CPU tensors plus separate metadata dicts. +Device transfers and threading are handled elsewhere (Dataset and DataLoader). +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Iterator + +import torch +from tensordict import TensorDict + + +class Reader(ABC): + """ + Abstract base class for data readers. + + Readers are intentionally simple and transactional: + - Load data from a source (file, database, etc.) + - Return (TensorDict, metadata_dict) tuples with CPU tensors + - No threading, no prefetching, no device transfers + + This design makes custom readers easy to implement. Users only need to: + 1. Implement `_load_sample(index)` to load raw data + 2. Implement `__len__()` to return dataset size + + Device transfers are handled automatically by Dataset (if device parameter set). + Threading/prefetching is handled by the DataLoader. + + Example custom reader: + >>> class MyReader(Reader): + ... def __init__(self, path: str, **kwargs): + ... super().__init__(**kwargs) + ... self.data = load_my_data(path) + ... + ... def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + ... return {"x": torch.from_numpy(self.data[index])} + ... + ... def __len__(self) -> int: + ... return len(self.data) + + Subclasses must implement: + - _load_sample(index: int) -> dict[str, torch.Tensor] + - __len__() -> int + + Optionally override: + - _get_field_names() -> list[str] + - _get_sample_metadata(index: int) -> dict[str, Any] + - close() + """ + + def __init__( + self, + *, + pin_memory: bool = False, + include_index_in_metadata: bool = True, + coordinated_subsampling: dict[str, Any] | None = None, + ) -> None: + """ + Initialize the reader. + + Args: + pin_memory: If True, place tensors in pinned (page-locked) memory. + This enables faster async CPU→GPU transfers later. + Only use if you plan to move data to GPU (default: False). + include_index_in_metadata: If True, include sample index in metadata + (default: True). + coordinated_subsampling: Optional dict to configure coordinated + subsampling at construction time. If provided, must contain: + - ``n_points``: Number of points to read from each target tensor + - ``target_keys``: List of tensor keys to apply subsampling to + This allows configuration via Hydra. Readers that don't support + coordinated subsampling will ignore this parameter. + """ + self.pin_memory = pin_memory + self.include_index_in_metadata = include_index_in_metadata + self._coordinated_subsampling_config = coordinated_subsampling + + @abstractmethod + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + """ + Load raw data for a single sample. + + This is the main method to implement. Load data from your source + and return it as a dictionary of CPU tensors. + + Args: + index: Sample index (0 to len-1). + + Returns: + Dictionary mapping field names to CPU tensors. + + Raises: + IndexError: If index is out of range. + """ + raise NotImplementedError + + @abstractmethod + def __len__(self) -> int: + """Return the number of samples in the dataset.""" + raise NotImplementedError + + def _get_field_names(self) -> list[str]: + """ + Return the list of field names in samples. + + Override this to provide field names without loading a sample. + Default implementation loads sample 0 and extracts keys. 
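+
+        Example:
+            A minimal sketch of an override for a reader whose schema is known
+            up front (``self._fields`` here is a hypothetical list set in
+            ``__init__``):
+
+            >>> def _get_field_names(self) -> list[str]:
+            ...     return list(self._fields)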
+ """ + if len(self) == 0: + return [] + data = self._load_sample(0) + return list(data.keys()) + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """ + Return metadata for a sample. + + Override this to provide source-specific metadata (filenames, etc.). + Default implementation returns empty dict (index added separately). + + Args: + index: Sample index. + + Returns: + Dictionary of metadata (not tensors). + """ + return {} + + @property + def _supports_coordinated_subsampling(self) -> bool: + """ + Return True if this reader supports coordinated subsampling. + + Override this property in subclasses that implement coordinated subsampling. + """ + return False + + @property + def field_names(self) -> list[str]: + """List of field names available in samples.""" + return self._get_field_names() + + def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: + """ + Load and return a single sample. + + Args: + index: Sample index. Supports negative indexing. + + Returns: + Tuple of (TensorDict with CPU tensors, metadata dict). + + Raises: + IndexError: If index is out of range. + """ + # Handle negative indexing + if index < 0: + index = len(self) + index + if index < 0 or index >= len(self): + raise IndexError( + f"Index {index} out of range for reader with {len(self)} samples" + ) + + # Load data + data_dict = self._load_sample(index) + + # Build metadata + metadata = self._get_sample_metadata(index) + if self.include_index_in_metadata: + metadata["index"] = index + + # Pin memory if requested + if self.pin_memory: + data_dict = {k: v.pin_memory() for k, v in data_dict.items()} + + # Create TensorDict + data = TensorDict(data_dict, device=torch.device("cpu")) + + return data, metadata + + def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: + """Iterate over all samples.""" + for i in range(len(self)): + yield self[i] + + def close(self) -> None: + """ + Clean up resources (file handles, connections, etc.). + + Override this in subclasses that hold open resources. + """ + pass + + def __enter__(self) -> "Reader": + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.close() + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(len={len(self)}, pin_memory={self.pin_memory})" + ) diff --git a/physicsnemo/datapipes/core/readers/hdf5.py b/physicsnemo/datapipes/core/readers/hdf5.py new file mode 100644 index 0000000000..74e6086b3e --- /dev/null +++ b/physicsnemo/datapipes/core/readers/hdf5.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +HDF5Reader - Read data from HDF5 files. + +Supports reading from single HDF5 files or directories of HDF5 files. 
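+
+Each indexed access returns a ``(TensorDict, metadata)`` pair. A minimal sketch,
+assuming a file ``data.h5`` with a ``pressure`` dataset of shape ``(N, 100)``:
+
+    >>> reader = HDF5Reader("data.h5", fields=["pressure"])
+    >>> data, metadata = reader[0]
+    >>> data["pressure"].shape
+    torch.Size([100])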
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Optional, Union + +import torch + +try: + import h5py + + HAS_H5PY = True +except ImportError: + HAS_H5PY = False + +from physicsnemo.datapipes.core.readers.base import Reader +from physicsnemo.datapipes.core.registry import register_reader + + +@register_reader() +class HDF5Reader(Reader): + """ + Read samples from HDF5 files. + + Supports two modes: + 1. Single file with samples indexed along first dimension of datasets + 2. Directory of HDF5 files, one sample per file + + Example (single file): + >>> # File structure: data.h5 with datasets "pressure" (N, 100), "velocity" (N, 100, 3) + >>> reader = HDF5Reader("data.h5", fields=["pressure", "velocity"]) + >>> sample = reader[0] # Returns Sample with pressure[100] and velocity[100, 3] + + Example (directory): + >>> # Directory with sample_0.h5, sample_1.h5, ... + >>> reader = HDF5Reader("data_dir/", file_pattern="sample_*.h5") + >>> sample = reader[0] # Loads all datasets from sample_0.h5 + """ + + def __init__( + self, + path: Path | str, + *, + fields: Optional[list[str]] = None, + file_pattern: str = "*.h5", + index_key: Optional[str] = None, + pin_memory: bool = False, + include_index_in_metadata: bool = True, + ) -> None: + """ + Initialize the HDF5 reader. + + Args: + path: Path to HDF5 file or directory containing HDF5 files. + fields: List of dataset names to load. If None, loads all datasets. + file_pattern: Glob pattern for finding files (directory mode only). + index_key: If provided, use this dataset to determine sample count + instead of inferring from first dimension. + pin_memory: If True, place tensors in pinned memory for faster GPU transfer. + include_index_in_metadata: If True, include sample index in metadata. + + Raises: + ImportError: If h5py is not installed. + FileNotFoundError: If path doesn't exist. + ValueError: If no HDF5 files found in directory. + """ + if not HAS_H5PY: + raise ImportError( + "h5py is required for HDF5Reader. 
Install with: pip install h5py" + ) + + super().__init__( + pin_memory=pin_memory, + include_index_in_metadata=include_index_in_metadata, + ) + + self.path = Path(path) + self.fields = fields + self.file_pattern = file_pattern + self.index_key = index_key + + if not self.path.exists(): + raise FileNotFoundError(f"Path not found: {self.path}") + + # Determine mode: single file or directory + self._is_directory = self.path.is_dir() + + if self._is_directory: + # Directory mode: each file is a sample + self._files = sorted(self.path.glob(file_pattern)) + if not self._files: + raise ValueError( + f"No files matching '{file_pattern}' found in {self.path}" + ) + self._length = len(self._files) + self._h5_file = None + + # Discover fields from first file + if self.fields is None: + with h5py.File(self._files[0], "r") as f: + self.fields = [ + k for k in f.keys() if isinstance(f[k], h5py.Dataset) + ] + else: + # Single file mode: samples indexed along first dimension + self._files = None + self._h5_file = h5py.File(self.path, "r") + + # Discover fields + if self.fields is None: + self.fields = [ + k + for k in self._h5_file.keys() + if isinstance(self._h5_file[k], h5py.Dataset) + ] + + # Determine length + if self.index_key is not None: + self._length = self._h5_file[self.index_key].shape[0] + elif self.fields: + self._length = self._h5_file[self.fields[0]].shape[0] + else: + self._length = 0 + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + """Load a single sample from HDF5.""" + data = {} + + if self._is_directory: + # Directory mode: load all datasets from the file + file_path = self._files[index] + with h5py.File(file_path, "r") as f: + for field in self.fields: + if field not in f: + raise KeyError( + f"Field '{field}' not found in {file_path}. " + f"Available: {list(f.keys())}" + ) + arr = f[field][:] + data[field] = torch.from_numpy(arr) + else: + # Single file mode: index into datasets + for field in self.fields: + if field not in self._h5_file: + raise KeyError( + f"Field '{field}' not found in {self.path}. " + f"Available: {list(self._h5_file.keys())}" + ) + arr = self._h5_file[field][index] + data[field] = torch.from_numpy(arr) + + return data + + def __len__(self) -> int: + """Return number of samples.""" + return self._length + + def _get_field_names(self) -> list[str]: + """Return field names.""" + return self.fields if self.fields else [] + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """Return metadata for a sample including source file info.""" + if self._is_directory: + return { + "source_file": str(self._files[index]), + "source_filename": self._files[index].name, + } + else: + return { + "source_file": str(self.path), + "source_filename": self.path.name, + } + + def close(self) -> None: + """Close HDF5 file handle.""" + super().close() + if self._h5_file is not None: + self._h5_file.close() + self._h5_file = None + + def __repr__(self) -> str: + mode = "directory" if self._is_directory else "file" + return ( + f"HDF5Reader(" + f"path={self.path}, " + f"mode={mode}, " + f"len={len(self)}, " + f"fields={self.fields})" + ) diff --git a/physicsnemo/datapipes/core/readers/numpy.py b/physicsnemo/datapipes/core/readers/numpy.py new file mode 100644 index 0000000000..b378b7df01 --- /dev/null +++ b/physicsnemo/datapipes/core/readers/numpy.py @@ -0,0 +1,315 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +NumpyReader - Read data from NumPy .npz files. + +Supports reading from single .npz files or directories of .npz files. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np +import torch + +from physicsnemo.datapipes.core.readers.base import Reader +from physicsnemo.datapipes.core.registry import register_reader + + +@register_reader() +class NumpyReader(Reader): + """ + Read samples from NumPy .npz files. + + Supports two modes: + 1. Single .npz file: samples indexed along first dimension of each array + 2. Directory of .npz files: one sample per file + + Example (single .npz): + >>> # data.npz with arrays "positions" (N, 100, 3), "features" (N, 100) + >>> reader = NumpyReader("data.npz", fields=["positions", "features"]) + >>> sample = reader[0] + >>> # Or load all arrays: + >>> reader = NumpyReader("data.npz") # fields=None loads all + + Example (directory): + >>> # Directory with sample_0.npz, sample_1.npz, ... + >>> reader = NumpyReader("data_dir/", file_pattern="sample_*.npz") + >>> sample = reader[0] + """ + + def __init__( + self, + path: Union[str, Path], + *, + fields: Optional[list[str]] = None, + default_values: Optional[dict[str, torch.Tensor]] = None, + file_pattern: str = "*.npz", + index_key: Optional[str] = None, + pin_memory: bool = False, + include_index_in_metadata: bool = True, + coordinated_subsampling: Optional[dict[str, Any]] = None, + ) -> None: + """ + Initialize the NumPy reader. + + Args: + path: Path to .npz file or directory of .npz files. + fields: List of array names to load. If None, loads all available + arrays from the file. + default_values: Dictionary mapping field names to default tensors. + If a field in ``fields`` is not found in the file but has an + entry here, the default tensor is used instead of raising an + error. Useful for optional fields. + file_pattern: Glob pattern for finding files (directory mode). + index_key: If provided, use this array to determine sample count. + pin_memory: If True, place tensors in pinned memory for faster GPU transfer. + include_index_in_metadata: If True, include sample index in metadata. + coordinated_subsampling: Optional dict to configure coordinated + subsampling (directory mode only). If provided, must contain + ``n_points`` (int) and ``target_keys`` (list of str). + + Raises: + FileNotFoundError: If path doesn't exist. + ValueError: If no files found in directory or unsupported file type. 
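+
+        Example:
+            A minimal sketch of coordinated subsampling in directory mode,
+            assuming each ``.npz`` file holds ``coords`` and ``fields`` arrays
+            that share their first (point) dimension:
+
+            >>> reader = NumpyReader(
+            ...     "data_dir/",
+            ...     coordinated_subsampling={
+            ...         "n_points": 1024,
+            ...         "target_keys": ["coords", "fields"],
+            ...     },
+            ... )
+            >>> data, metadata = reader[0]  # One random 1024-point slice shared by both keys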
+ """ + super().__init__( + pin_memory=pin_memory, + include_index_in_metadata=include_index_in_metadata, + coordinated_subsampling=coordinated_subsampling, + ) + + self.path = Path(path).expanduser().resolve() + self._user_fields = fields + self.default_values = default_values or {} + self.file_pattern = file_pattern + self.index_key = index_key + + if not self.path.exists(): + raise FileNotFoundError(f"Path not found: {self.path}") + + # Determine mode based on path + self._mode: str # "single" or "directory" + self._files: Optional[list[Path]] = None + self._data: Optional[np.lib.npyio.NpzFile] = None + self._available_fields: list[str] = [] + + if self.path.is_dir(): + self._setup_directory_mode() + elif self.path.suffix == ".npz": + self._setup_single_file_mode() + else: + raise ValueError( + f"Unsupported file type: {self.path.suffix}. " + f"Expected .npz file or directory of .npz files." + ) + + def _setup_directory_mode(self) -> None: + """Set up reader for directory of .npz files.""" + self._mode = "directory" + self._files = sorted(self.path.glob(self.file_pattern)) + if not self._files: + raise ValueError( + f"No files matching '{self.file_pattern}' found in {self.path}" + ) + self._length = len(self._files) + + # Discover available fields from first file + with np.load(self._files[0]) as npz: + self._available_fields = list(npz.files) + + def _setup_single_file_mode(self) -> None: + """Set up reader for single .npz file.""" + self._mode = "single" + self._data = np.load(self.path) + self._available_fields = list(self._data.files) + + # Determine length from index_key or first field + if self.index_key is not None: + self._length = self._data[self.index_key].shape[0] + elif self._available_fields: + self._length = self._data[self._available_fields[0]].shape[0] + else: + self._length = 0 + + @property + def fields(self) -> list[str]: + """Fields that will be loaded (user-specified or all available).""" + if self._user_fields is not None: + return self._user_fields + return self._available_fields + + def _select_random_sections_from_slice( + self, + slice_start: int, + slice_stop: int, + n_points: int, + ) -> slice: + """ + Select a random contiguous slice from a range. + + Args: + slice_start: Start index of the available range. + slice_stop: Stop index of the available range (exclusive). + n_points: Number of points to sample. + + Returns: + A slice object representing the random contiguous section. + + Raises: + ValueError: If the range is smaller than n_points. + """ + total_points = slice_stop - slice_start + + if total_points < n_points: + raise ValueError( + f"Slice size {total_points} is less than the number of points " + f"{n_points} requested for subsampling" + ) + + start = np.random.randint(slice_start, slice_stop - n_points + 1) + return slice(start, start + n_points) + + def _load_from_npz( + self, + npz: np.lib.npyio.NpzFile, + index: Optional[int] = None, + file_path: Optional[Path] = None, + ) -> dict[str, torch.Tensor]: + """ + Load data from an npz file. + + Args: + npz: The loaded npz file object. + index: Sample index to load (for single file mode with indexed arrays). + None for directory mode (load entire arrays). + file_path: Path to the file (for error messages). + + Returns: + Dictionary mapping field names to tensors. 
+ """ + data = {} + fields_to_load = self.fields + + # Check for missing required fields + required_fields = set(fields_to_load) - set(self.default_values.keys()) + missing_fields = required_fields - set(npz.files) + if missing_fields: + path_str = str(file_path) if file_path else str(self.path) + raise KeyError( + f"Required fields {missing_fields} not found in {path_str}. " + f"Available: {list(npz.files)}" + ) + + # Determine subsample slice if coordinated subsampling is enabled + subsample_slice = None + target_keys_set = set() + if self._coordinated_subsampling_config is not None: + n_points = self._coordinated_subsampling_config["n_points"] + target_keys_set = set(self._coordinated_subsampling_config["target_keys"]) + + # Find slice from first available target key + for field in target_keys_set: + if field in npz.files: + array_shape = npz[field].shape[0] + subsample_slice = self._select_random_sections_from_slice( + 0, array_shape, n_points + ) + break + + # Load each field + for field in fields_to_load: + if field in npz.files: + arr = npz[field] + + # Apply indexing if provided (single file mode) + if index is not None: + arr = arr[index] + + # Apply subsampling if this field is a target + if subsample_slice is not None and field in target_keys_set: + arr = arr[subsample_slice] + elif index is None: + # Directory mode: load full array + arr = arr[:] + + data[field] = torch.from_numpy(np.array(arr)) + + elif field in self.default_values: + data[field] = self.default_values[field].clone() + + return data + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + """Load a single sample.""" + if self._mode == "directory": + file_path = self._files[index] + with np.load(file_path) as npz: + return self._load_from_npz(npz, index=None, file_path=file_path) + else: # single + return self._load_from_npz(self._data, index=index) + + def __len__(self) -> int: + """Return number of samples.""" + return self._length + + def _get_field_names(self) -> list[str]: + """Return field names that will be loaded.""" + return self.fields + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """Return metadata for a sample including source file info.""" + if self._mode == "directory": + return { + "source_file": str(self._files[index]), + "source_filename": self._files[index].name, + } + else: + return { + "source_file": str(self.path), + "source_filename": self.path.name, + } + + @property + def _supports_coordinated_subsampling(self) -> bool: + """NumPy reader supports coordinated subsampling in directory mode.""" + return self._mode == "directory" + + def close(self) -> None: + """Close file handles.""" + super().close() + if self._data is not None: + if hasattr(self._data, "close"): + self._data.close() + self._data = None + + def __repr__(self) -> str: + subsample_info = "" + if self._coordinated_subsampling_config is not None: + cfg = self._coordinated_subsampling_config + subsample_info = f", subsampling={cfg['n_points']} points" + + return ( + f"NumpyReader(" + f"path={self.path}, " + f"mode={self._mode}, " + f"len={len(self)}, " + f"fields={self.fields}" + f"{subsample_info})" + ) diff --git a/physicsnemo/datapipes/core/readers/tensorstore_zarr.py b/physicsnemo/datapipes/core/readers/tensorstore_zarr.py new file mode 100644 index 0000000000..25f36ca757 --- /dev/null +++ b/physicsnemo/datapipes/core/readers/tensorstore_zarr.py @@ -0,0 +1,379 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +TensorStoreZarrReader - High-performance async reader for Zarr files using TensorStore. + +Provides faster I/O than standard Zarr library through async operations and +optimized caching. Supports coordinated subsampling for large arrays. +""" + +from __future__ import annotations + +import importlib +import json +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np +import torch + +from physicsnemo.core.version_check import check_version_spec +from physicsnemo.datapipes.core.readers.base import Reader +from physicsnemo.datapipes.core.registry import register_reader + +# Check if tensorstore is available +TENSORSTORE_AVAILABLE = check_version_spec("tensorstore", hard_fail=False) + +if TENSORSTORE_AVAILABLE: + ts = importlib.import_module("tensorstore") + + +@register_reader() +class TensorStoreZarrReader(Reader): + r""" + High-performance async reader for Zarr files using TensorStore. + + This reader provides faster I/O than the standard ZarrReader through async + operations, optimized caching, and concurrent data fetching. It's particularly + beneficial for large datasets on networked storage or cloud storage. + + This is a drop-in replacement for ZarrReader with identical interface. + Each Zarr group in the directory represents one sample. + + Example: + >>> # Directory with sample_0.zarr, sample_1.zarr, ... + >>> reader = TensorStoreZarrReader("data_dir/", group_pattern="sample_*.zarr") + >>> sample = reader[0] + + >>> # Load only specific fields: + >>> reader = TensorStoreZarrReader("data_dir/", fields=["positions", "velocity"]) + >>> sample = reader[0] + + >>> # With coordinated subsampling for large arrays: + >>> reader = TensorStoreZarrReader( + ... "data_dir/", + ... coordinated_subsampling={ + ... "n_points": 50000, + ... "target_keys": ["volume_coords", "volume_fields"], + ... } + ... ) + >>> sample = reader[0] + + Performance Tips: + - Increase ``cache_bytes_limit`` for better performance on repeated access + - Increase ``data_copy_concurrency`` and ``file_io_concurrency`` for + parallel workloads + - Use coordinated subsampling when reading subsets of large arrays + """ + + def __init__( + self, + path: Union[str, Path], + *, + fields: Optional[list[str]] = None, + default_values: Optional[dict[str, torch.Tensor]] = None, + group_pattern: str = "*.zarr", + cache_bytes_limit: int = 10_000_000, + data_copy_concurrency: int = 72, + file_io_concurrency: int = 72, + pin_memory: bool = False, + include_index_in_metadata: bool = True, + coordinated_subsampling: Optional[dict[str, Any]] = None, + ) -> None: + """ + Initialize the TensorStore Zarr reader. + + Args: + path: Path to directory containing Zarr groups. + fields: List of array names to load. If None, loads all available + arrays from each group. + default_values: Dictionary mapping field names to default tensors. 
+ If a field in ``fields`` is not found in the file but has an + entry here, the default tensor is used instead of raising an + error. Useful for optional fields. + group_pattern: Glob pattern for finding Zarr groups. + cache_bytes_limit: Total cache size in bytes (default: 10 MB). + data_copy_concurrency: Limit for concurrent data copy operations. + file_io_concurrency: Limit for concurrent file I/O operations. + pin_memory: If True, place tensors in pinned memory. + include_index_in_metadata: If True, include sample index in metadata. + coordinated_subsampling: Optional dict to configure coordinated + subsampling. If provided, must contain ``n_points`` (int) and + ``target_keys`` (list of str). + + Raises: + ImportError: If TensorStore is not installed. + FileNotFoundError: If path doesn't exist. + ValueError: If no Zarr groups found. + """ + if not TENSORSTORE_AVAILABLE: + raise ImportError( + "TensorStore is required for TensorStoreZarrReader but is not installed.\n" + "Install it with: pip install tensorstore\n" + "See https://google.github.io/tensorstore/ for more information." + ) + + super().__init__( + pin_memory=pin_memory, + include_index_in_metadata=include_index_in_metadata, + coordinated_subsampling=coordinated_subsampling, + ) + + self.path = Path(path).expanduser().resolve() + self._user_fields = fields + self.default_values = default_values or {} + self.group_pattern = group_pattern + + if not self.path.exists(): + raise FileNotFoundError(f"Path not found: {self.path}") + + if not self.path.is_dir(): + raise ValueError( + f"Path must be a directory containing Zarr groups: {self.path}" + ) + + # Find all Zarr groups + self._groups = sorted( + [ + p + for p in self.path.glob(group_pattern) + if p.is_dir() + and ((p / ".zgroup").exists() or (p / "zarr.json").exists()) + ] + ) + + if not self._groups: + raise ValueError( + f"No Zarr groups matching '{group_pattern}' found in {self.path}" + ) + + self._length = len(self._groups) + + # Discover available fields from first group + self._available_fields = self._discover_fields(self._groups[0]) + + # Create TensorStore context with caching config + self._context = ts.Context( + { + "cache_pool": {"total_bytes_limit": cache_bytes_limit}, + "data_copy_concurrency": {"limit": data_copy_concurrency}, + "file_io_concurrency": {"limit": file_io_concurrency}, + } + ) + + # Spec template for opening Zarr arrays + self._spec_template = { + "driver": "zarr", + "kvstore": { + "driver": "file", + "path": None, + }, + } + + def _discover_fields(self, group_path: Path) -> list[str]: + """Discover array names in a Zarr group (v2 or v3 format).""" + fields = [] + + # List subdirectories that are zarr arrays + for item in group_path.iterdir(): + if not item.is_dir(): + continue + + # Zarr v2: arrays have .zarray metadata file + if (item / ".zarray").exists(): + fields.append(item.name) + # Zarr v3: arrays have zarr.json with node_type="array" + elif (item / "zarr.json").exists(): + try: + with open(item / "zarr.json") as f: + metadata = json.load(f) + if metadata.get("node_type") == "array": + fields.append(item.name) + except (json.JSONDecodeError, OSError): + # Skip malformed or unreadable metadata + pass + + return sorted(fields) + + @property + def fields(self) -> list[str]: + """Fields that will be loaded (user-specified or all available).""" + if self._user_fields is not None: + return self._user_fields + return self._available_fields + + def _read_attributes(self, group_path: Path) -> dict[str, Any]: + """Read attributes from a Zarr group 
(v2 or v3).""" + store_spec = {"driver": "file", "path": str(group_path)} + store = ts.KvStore.open(store_spec).result() + + keys = store.list().result() + + # Try Zarr v3 format first + if b"/zarr.json" in keys: + zarr_json = store.read(b"/zarr.json").result() + metadata = json.loads(zarr_json.value) + if "attributes" in metadata: + return {k: torch.tensor(v) for k, v in metadata["attributes"].items()} + return {} + + # Try Zarr v2 format + elif b"/.zattrs" in keys: + zarr_attrs = store.read(b"/.zattrs").result() + metadata = json.loads(zarr_attrs.value) + return {k: torch.tensor(v) for k, v in metadata.items()} + + return {} + + def _select_random_sections_from_slice( + self, + slice_start: int, + slice_stop: int, + n_points: int, + ) -> slice: + """Select a random contiguous slice from a range.""" + total_points = slice_stop - slice_start + + if total_points < n_points: + raise ValueError( + f"Slice size {total_points} is less than the number of points " + f"{n_points} requested for subsampling" + ) + + start = np.random.randint(slice_start, slice_stop - n_points + 1) + return slice(start, start + n_points) + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + """Load a single sample from a Zarr group using TensorStore.""" + group_path = self._groups[index] + + # Read attributes (stored as tensors in sample) + attributes = self._read_attributes(group_path) + + # Determine which fields to read + fields_to_load = self.fields + fields_from_arrays = set(fields_to_load) - set(attributes.keys()) + + # Check for missing required fields using cached available fields + # (discovered once during __init__ from the first group) + available = set(self._available_fields) + required_fields = fields_from_arrays - set(self.default_values.keys()) + missing_fields = required_fields - available + if missing_fields: + raise KeyError( + f"Required fields {missing_fields} not found in {group_path}. 
" + f"Available: {list(available)}" + ) + + # Determine subsample slice if coordinated subsampling is enabled + subsample_slice = None + target_keys_set = set() + if self._coordinated_subsampling_config is not None: + n_points = self._coordinated_subsampling_config["n_points"] + target_keys_set = set(self._coordinated_subsampling_config["target_keys"]) + + # Open all array stores asynchronously + read_futures = {} + for key in fields_from_arrays: + if key not in available: + continue + + spec = { + "driver": "auto", + "kvstore": { + "driver": "file", + "path": str(group_path / key), + }, + } + read_futures[key] = ts.open( + spec, create=False, open=True, context=self._context + ) + + # Wait for opens to complete + stores = {key: future.result() for key, future in read_futures.items()} + + # Determine subsample slice if needed + if subsample_slice is None and self._coordinated_subsampling_config is not None: + for key in target_keys_set: + if key in stores: + array_shape = stores[key].shape[0] + subsample_slice = self._select_random_sections_from_slice( + 0, array_shape, n_points + ) + break + + # Trigger async reads + tensor_futures = {} + for key in fields_from_arrays: + if key not in stores: + continue + + # Apply subsampling if this key is a target + if subsample_slice is not None and key in target_keys_set: + tensor_futures[key] = stores[key][subsample_slice].read() + else: + tensor_futures[key] = stores[key][:].read() + + # Wait for reads and convert to torch tensors + data = { + key: torch.as_tensor(future.result(), dtype=torch.float32) + for key, future in tensor_futures.items() + } + + # Add attributes + data.update(attributes) + + # Add default values for missing optional fields + for key, default_value in self.default_values.items(): + if key not in data: + data[key] = default_value.clone() + + return data + + def __len__(self) -> int: + """Return number of samples.""" + return self._length + + def _get_field_names(self) -> list[str]: + """Return field names that will be loaded.""" + return self.fields + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """Return metadata for a sample.""" + return { + "source_file": str(self._groups[index]), + "source_filename": self._groups[index].name, + } + + @property + def _supports_coordinated_subsampling(self) -> bool: + """TensorStore Zarr reader supports coordinated subsampling.""" + return True + + def __repr__(self) -> str: + subsample_info = "" + if self._coordinated_subsampling_config is not None: + cfg = self._coordinated_subsampling_config + subsample_info = f", subsampling={cfg['n_points']} points" + + return ( + f"TensorStoreZarrReader(" + f"path={self.path}, " + f"len={len(self)}, " + f"fields={self.fields}" + f"{subsample_info})" + ) diff --git a/physicsnemo/datapipes/core/readers/vtk.py b/physicsnemo/datapipes/core/readers/vtk.py new file mode 100644 index 0000000000..e4bb147b9f --- /dev/null +++ b/physicsnemo/datapipes/core/readers/vtk.py @@ -0,0 +1,307 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VTKReader - Read data from VTK format files (.stl, .vtp, .vtu). + +Supports reading mesh data from directories containing VTK files using PyVista. +""" + +from __future__ import annotations + +import importlib +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np +import torch + +from physicsnemo.core.version_check import check_version_spec +from physicsnemo.datapipes.core.readers.base import Reader +from physicsnemo.datapipes.core.registry import register_reader + +# Check if pyvista is available +PYVISTA_AVAILABLE = check_version_spec("pyvista", hard_fail=False) + +if PYVISTA_AVAILABLE: + pv = importlib.import_module("pyvista") + + +@register_reader() +class VTKReader(Reader): + r""" + Read samples from VTK format files (.stl, .vtp, .vtu). + + This reader loads mesh data from directories where each subdirectory contains + VTK files representing one sample. Supports STL (surface meshes), VTP + (PolyData), and VTU (UnstructuredGrid) formats. + + Requires PyVista to be installed. If PyVista is not available, attempting + to instantiate this reader will raise an ImportError with installation + instructions. + + Example: + >>> # Directory structure: + >>> # data/ + >>> # sample_0/ + >>> # geometry.stl + >>> # surface.vtp + >>> # sample_1/ + >>> # geometry.stl + >>> # surface.vtp + >>> # ... + >>> + >>> reader = VTKReader( + ... "data/", + ... keys_to_read=["stl_coordinates", "stl_faces", "surface_normals"], + ... ) + >>> sample = reader[0] + >>> print(sample["stl_coordinates"].shape) # (N, 3) + + Available Keys: + From .stl files: + - ``stl_coordinates``: Vertex coordinates, shape :math:`(N, 3)` + - ``stl_faces``: Face indices (flattened), shape :math:`(M*3,)` + - ``stl_centers``: Face centers, shape :math:`(M, 3)` + - ``surface_normals``: Face normals, shape :math:`(M, 3)` + + From .vtp files: + - ``surface_mesh_centers``: Cell centers + - ``surface_normals``: Cell normals + - ``surface_mesh_sizes``: Cell areas + - Additional fields from the VTP file + + Note: + VTK files are typically small enough to fit in memory, so coordinated + subsampling is not supported. Use transforms for downsampling if needed. + """ + + def __init__( + self, + path: Union[str, Path], + *, + keys_to_read: Optional[list[str]] = None, + exclude_patterns: Optional[list[str]] = None, + pin_memory: bool = False, + include_index_in_metadata: bool = True, + ) -> None: + """ + Initialize the VTK reader. + + Args: + path: Path to directory containing subdirectories with VTK files. + keys_to_read: List of keys to extract from VTK files. + If None, extracts all available data. + exclude_patterns: List of filename patterns to exclude (e.g., ["single_solid"]). + pin_memory: If True, place tensors in pinned memory for faster GPU transfer. + include_index_in_metadata: If True, include sample index in metadata. + + Raises: + ImportError: If PyVista is not installed. + FileNotFoundError: If path doesn't exist. + ValueError: If no valid VTK directories found. 
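+
+        Example:
+            A minimal sketch, assuming each subdirectory of ``data/`` contains
+            an ``.stl`` file and that filenames containing the hypothetical
+            pattern ``"coarse"`` should be skipped:
+
+            >>> reader = VTKReader(
+            ...     "data/",
+            ...     keys_to_read=["stl_coordinates", "stl_faces"],
+            ...     exclude_patterns=["coarse"],
+            ... )
+            >>> data, metadata = reader[0]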
+ """ + if not PYVISTA_AVAILABLE: + raise ImportError( + "PyVista is required for VTKReader but is not installed.\n" + "Install it with: pip install pyvista\n" + "See https://docs.pyvista.org/getting-started/installation.html " + "for more information." + ) + + super().__init__( + pin_memory=pin_memory, + include_index_in_metadata=include_index_in_metadata, + ) + + self.path = Path(path) + self.keys_to_read = keys_to_read + self.exclude_patterns = exclude_patterns or ["single_solid"] + + if not self.path.exists(): + raise FileNotFoundError(f"Path not found: {self.path}") + + if not self.path.is_dir(): + raise ValueError(f"Path must be a directory: {self.path}") + + # Find all subdirectories containing VTK files + self._directories = [] + for subdir in self.path.iterdir(): + if subdir.is_dir() and self._is_vtk_directory(subdir): + self._directories.append(subdir) + + self._directories = sorted(self._directories) + + if not self._directories: + raise ValueError( + f"No directories containing VTK files found in {self.path}" + ) + + self._length = len(self._directories) + + # Supported file keys mapped to file extensions + self._stl_keys = { + "stl_coordinates", + "stl_centers", + "stl_faces", + "stl_areas", + "surface_normals", + } + self._vtp_keys = { + "surface_mesh_centers", + "surface_normals", + "surface_mesh_sizes", + "CpMeanTrim", + "pMeanTrim", + "wallShearStressMeanTrim", + } + self._vtu_keys = { + "volume_mesh_centers", + "volume_fields", + } + + def _is_vtk_directory(self, directory: Path) -> bool: + """Check if a directory contains VTK files.""" + vtk_extensions = {".stl", ".vtp", ".vtu", ".vtk"} + for file in directory.iterdir(): + if file.suffix in vtk_extensions: + return True + return False + + def _get_file_by_extension(self, directory: Path, extension: str) -> Optional[Path]: + """Get the first file with the given extension, excluding patterns.""" + for file in directory.iterdir(): + if file.suffix == extension: + # Check if any exclude pattern is in the filename + if not any(pattern in file.name for pattern in self.exclude_patterns): + return file + return None + + def _read_stl_data(self, stl_path: Path) -> dict[str, torch.Tensor]: + """Read data from an STL file.""" + mesh = pv.read(stl_path) + + data = {} + + # Extract faces (reshape from flat array to triangles) + faces = mesh.faces.reshape(-1, 4) + faces = faces[:, 1:] # Remove the first column (always 3 for triangles) + data["stl_faces"] = torch.from_numpy(faces.flatten()) + + # Extract coordinates + data["stl_coordinates"] = torch.from_numpy(mesh.points) + + # Extract normals + data["surface_normals"] = torch.from_numpy(mesh.cell_normals) + + # Compute face centers (for stl_centers) + # Each face has 3 vertices, compute the mean + vertices = mesh.points + face_indices = faces + face_centers = vertices[face_indices].mean(axis=1) + data["stl_centers"] = torch.from_numpy(face_centers) + + # Compute face areas (for stl_areas) + # Area of triangle = 0.5 * ||cross(v1-v0, v2-v0)|| + v0 = vertices[face_indices[:, 0]] + v1 = vertices[face_indices[:, 1]] + v2 = vertices[face_indices[:, 2]] + cross_prod = np.cross(v1 - v0, v2 - v0) + areas = 0.5 * np.linalg.norm(cross_prod, axis=1) + data["stl_areas"] = torch.from_numpy(areas) + + return data + + def _read_vtp_data(self, vtp_path: Path) -> dict[str, torch.Tensor]: + """Read data from a VTP file.""" + # VTP reading is not yet implemented in the original cae_dataset.py + # Placeholder for future implementation + raise NotImplementedError( + "VTP file reading is not yet 
implemented. " + "This will be added in a future update." + ) + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + """Load a single sample from a VTK directory.""" + directory = self._directories[index] + + result = {} + + # Determine which file types to read based on requested keys + need_stl = self.keys_to_read is None or any( + key in self._stl_keys for key in self.keys_to_read + ) + need_vtp = self.keys_to_read is not None and any( + key in self._vtp_keys for key in self.keys_to_read + ) + need_vtu = self.keys_to_read is not None and any( + key in self._vtu_keys for key in self.keys_to_read + ) + + # Read STL data if needed + if need_stl: + stl_path = self._get_file_by_extension(directory, ".stl") + if stl_path: + stl_data = self._read_stl_data(stl_path) + result.update(stl_data) + + # Read VTP data if needed + if need_vtp: + vtp_path = self._get_file_by_extension(directory, ".vtp") + if vtp_path: + vtp_data = self._read_vtp_data(vtp_path) + result.update(vtp_data) + + # Read VTU data if needed + if need_vtu: + raise NotImplementedError("VTU file reading is not yet implemented.") + + # Filter to requested keys if specified + if self.keys_to_read is not None: + result = {k: v for k, v in result.items() if k in self.keys_to_read} + + return result + + def __len__(self) -> int: + """Return number of samples.""" + return self._length + + def _get_field_names(self) -> list[str]: + """Return field names.""" + if self.keys_to_read is not None: + return self.keys_to_read + + # Load first sample to discover available keys + if len(self) == 0: + return [] + + sample = self._load_sample(0) + return list(sample.keys()) + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """Return metadata for a sample including source directory info.""" + return { + "source_file": str(self._directories[index]), + "source_filename": self._directories[index].name, + } + + @property + def _supports_coordinated_subsampling(self) -> bool: + """VTK files don't support coordinated subsampling.""" + return False + + def __repr__(self) -> str: + return f"VTKReader(path={self.path}, len={len(self)}, keys={self.keys_to_read})" diff --git a/physicsnemo/datapipes/core/readers/zarr.py b/physicsnemo/datapipes/core/readers/zarr.py new file mode 100644 index 0000000000..ae113e8a1a --- /dev/null +++ b/physicsnemo/datapipes/core/readers/zarr.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ZarrReader - Read data from Zarr arrays. + +Supports reading from a directory of Zarr groups, one sample per group. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Optional, Union + +import numpy as np +import torch + +try: + import zarr + + HAS_ZARR = True +except ImportError: + HAS_ZARR = False + +from physicsnemo.datapipes.core.readers.base import Reader +from physicsnemo.datapipes.core.registry import register_reader + + +@register_reader() +class ZarrReader(Reader): + """ + Read samples from Zarr groups. + + Zarr is a chunked, compressed array format ideal for large scientific datasets. + Each Zarr group in the directory represents one sample. + + Example: + >>> # Directory with sample_0.zarr, sample_1.zarr, ... + >>> # Each contains arrays like "positions", "features", etc. + >>> reader = ZarrReader("data_dir/", group_pattern="sample_*.zarr") + >>> sample = reader[0] + + >>> # Load only specific fields: + >>> reader = ZarrReader("data_dir/", fields=["positions", "velocity"]) + >>> sample = reader[0] + + >>> # With coordinated subsampling for large arrays: + >>> reader = ZarrReader( + ... "data_dir/", + ... coordinated_subsampling={ + ... "n_points": 50000, + ... "target_keys": ["volume_coords", "volume_fields"], + ... } + ... ) + >>> sample = reader[0] + """ + + def __init__( + self, + path: Union[str, Path], + *, + fields: Optional[list[str]] = None, + default_values: Optional[dict[str, torch.Tensor]] = None, + group_pattern: str = "*.zarr", + pin_memory: bool = False, + include_index_in_metadata: bool = True, + coordinated_subsampling: Optional[dict[str, Any]] = None, + cache_stores: bool = True, + ) -> None: + """ + Initialize the Zarr reader. + + Args: + path: Path to directory containing Zarr groups. + fields: List of array names to load. If None, loads all available + arrays from each group. + default_values: Dictionary mapping field names to default tensors. + If a field in ``fields`` is not found in the file but has an + entry here, the default tensor is used instead of raising an + error. Useful for optional fields. + group_pattern: Glob pattern for finding Zarr groups. + pin_memory: If True, place tensors in pinned memory for faster GPU transfer. + include_index_in_metadata: If True, include sample index in metadata. + coordinated_subsampling: Optional dict to configure coordinated + subsampling. If provided, must contain ``n_points`` (int) and + ``target_keys`` (list of str). + cache_stores: If True (default), cache opened zarr stores to avoid + repeated opening and prevent executor shutdown errors. Set to + False if memory is a concern with many groups. + + Raises: + ImportError: If zarr is not installed. + FileNotFoundError: If path doesn't exist. + ValueError: If no Zarr groups found in directory. + """ + if not HAS_ZARR: + raise ImportError( + "zarr is required for ZarrReader. 
Install with: pip install zarr" + ) + + super().__init__( + pin_memory=pin_memory, + include_index_in_metadata=include_index_in_metadata, + coordinated_subsampling=coordinated_subsampling, + ) + + self.path = Path(path).expanduser().resolve() + self._user_fields = fields + self.default_values = default_values or {} + self.group_pattern = group_pattern + self._cache_stores = cache_stores + self._cached_stores: dict[Path, Any] = {} # Cache for opened zarr stores + + if not self.path.exists(): + raise FileNotFoundError(f"Path not found: {self.path}") + + if not self.path.is_dir(): + raise ValueError( + f"Path must be a directory containing Zarr groups: {self.path}" + ) + + # Detect mode: single group or directory of groups + self._single_group_mode = self._is_zarr_group(self.path) + + if self._single_group_mode: + # Single Zarr group - samples indexed along first dimension + self._groups = [self.path] + root = zarr.open(self.path, mode="r") + + if isinstance(root, zarr.Array): + raise ValueError( + f"Expected Zarr group with named arrays, got single array at " + f"{self.path}. Path should be a Zarr group containing named arrays." + ) + + self._available_fields = list(root.array_keys()) + + # Get length from first array's first dimension + if not self._available_fields: + raise ValueError(f"Zarr group {self.path} contains no arrays") + + first_array = root[self._available_fields[0]] + self._length = first_array.shape[0] + else: + # Directory containing multiple Zarr groups + self._groups = sorted( + [p for p in self.path.glob(group_pattern) if self._is_zarr_group(p)] + ) + + if not self._groups: + raise ValueError( + f"No Zarr groups matching '{group_pattern}' found in {self.path}" + ) + + self._length = len(self._groups) + + # Discover available fields from first group + root = zarr.open(self._groups[0], mode="r") + if isinstance(root, zarr.Array): + raise ValueError( + f"Expected Zarr group with named arrays, got single array at " + f"{self._groups[0]}. Each sample should be a Zarr group containing " + f"named arrays." + ) + self._available_fields = list(root.array_keys()) + + @property + def fields(self) -> list[str]: + """Fields that will be loaded (user-specified or all available).""" + if self._user_fields is not None: + return self._user_fields + return self._available_fields + + def _open_zarr_store(self, path: Path) -> Any: + """ + Open a zarr store, using cache if enabled. + + This prevents the "cannot schedule new futures after shutdown" error + by reusing opened stores instead of repeatedly calling zarr.open(). + + Args: + path: Path to the zarr group. + + Returns: + Opened zarr group. + """ + if self._cache_stores: + if path not in self._cached_stores: + self._cached_stores[path] = zarr.open(path, mode="r") + return self._cached_stores[path] + else: + return zarr.open(path, mode="r") + + def _is_zarr_group(self, path: Path) -> bool: + """ + Check if a path is a Zarr group. + + A Zarr group is identified by the presence of a zarr.json file (v3) + or .zgroup file (v2). + """ + return (path / "zarr.json").exists() or (path / ".zgroup").exists() + + def _select_random_sections_from_slice( + self, + slice_start: int, + slice_stop: int, + n_points: int, + ) -> slice: + """ + Select a random contiguous slice from a range. + + Args: + slice_start: Start index of the available range. + slice_stop: Stop index of the available range (exclusive). + n_points: Number of points to sample. + + Returns: + A slice object representing the random contiguous section. 
+ + Raises: + ValueError: If the range is smaller than n_points. + """ + total_points = slice_stop - slice_start + + if total_points < n_points: + raise ValueError( + f"Slice size {total_points} is less than the number of points " + f"{n_points} requested for subsampling" + ) + + start = np.random.randint(slice_start, slice_stop - n_points + 1) + return slice(start, start + n_points) + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + """Load a single sample from a Zarr group.""" + if self._single_group_mode: + # Single group: index into first dimension of each array + group_path = self._groups[0] + root = self._open_zarr_store(group_path) + else: + # Directory mode: each group is one sample + group_path = self._groups[index] + root = self._open_zarr_store(group_path) + + data = {} + fields_to_load = self.fields + + # Check for missing required fields + available = set(root.array_keys()) + required_fields = set(fields_to_load) - set(self.default_values.keys()) + missing_fields = required_fields - available + if missing_fields: + raise KeyError( + f"Required fields {missing_fields} not found in {group_path}. " + f"Available: {list(available)}" + ) + + # Determine subsample slice if coordinated subsampling is enabled + subsample_slice = None + target_keys_set = set() + if self._coordinated_subsampling_config is not None: + n_points = self._coordinated_subsampling_config["n_points"] + target_keys_set = set(self._coordinated_subsampling_config["target_keys"]) + + # Find slice from first available target key + for field in target_keys_set: + if field in root: + if self._single_group_mode: + # In single group mode, subsample along dimensions after the first + array_shape = root[field].shape[1] + else: + array_shape = root[field].shape[0] + subsample_slice = self._select_random_sections_from_slice( + 0, array_shape, n_points + ) + break + + # Load each field + for field in fields_to_load: + if field in root: + if self._single_group_mode: + # Single group mode: index into first dimension + if subsample_slice is not None and field in target_keys_set: + # Apply subsampling on dimensions after the first + data[field] = torch.from_numpy( + root[field][index, subsample_slice] + ) + else: + data[field] = torch.from_numpy(root[field][index]) + else: + # Directory mode: load entire array or subsample + if subsample_slice is not None and field in target_keys_set: + data[field] = torch.from_numpy(root[field][subsample_slice]) + else: + data[field] = torch.from_numpy(root[field][:]) + + elif field in self.default_values: + data[field] = self.default_values[field].clone() + + return data + + def __len__(self) -> int: + """Return number of samples.""" + return self._length + + def _get_field_names(self) -> list[str]: + """Return field names that will be loaded.""" + return self.fields + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + """Return metadata for a sample including source info.""" + if self._single_group_mode: + return { + "source_file": str(self._groups[0]), + "source_filename": self._groups[0].name, + "sample_index": index, + } + else: + return { + "source_file": str(self._groups[index]), + "source_filename": self._groups[index].name, + } + + @property + def _supports_coordinated_subsampling(self) -> bool: + """Zarr reader supports coordinated subsampling.""" + return True + + def close(self) -> None: + """Close resources and cached zarr stores.""" + # Clear cached stores to allow garbage collection + # This helps prevent executor shutdown issues + 
self._cached_stores.clear() + super().close() + + def __repr__(self) -> str: + subsample_info = "" + if self._coordinated_subsampling_config is not None: + cfg = self._coordinated_subsampling_config + subsample_info = f", subsampling={cfg['n_points']} points" + + return ( + f"ZarrReader(" + f"path={self.path}, " + f"len={len(self)}, " + f"fields={self.fields}, " + f"cache_stores={self._cache_stores}" + f"{subsample_info})" + ) diff --git a/physicsnemo/datapipes/core/registry.py b/physicsnemo/datapipes/core/registry.py new file mode 100644 index 0000000000..ecca0bda1f --- /dev/null +++ b/physicsnemo/datapipes/core/registry.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Registry for datapipe components. + +Provides registries for transforms and readers, enabling: +- Short aliases in Hydra configuration +- Component discovery and introspection +- Consistent instantiation patterns + +Example usage: + >>> from physicsnemo.datapipes.core.registry import TRANSFORM_REGISTRY + >>> + >>> @TRANSFORM_REGISTRY.register() + ... class MyTransform(Transform): + ... pass + >>> + >>> # Get registered component by name + >>> cls = TRANSFORM_REGISTRY.get("MyTransform") + >>> + >>> # List all registered components + >>> print(TRANSFORM_REGISTRY.list()) +""" + +from __future__ import annotations + +from typing import Callable, Type, TypeVar + +T = TypeVar("T") + + +class ComponentRegistry: + """ + Registry for datapipe components with short aliases. + + A registry allows components (transforms, readers) to be registered + with a name and later retrieved by that name. This enables: + + - Hydra configuration with short names instead of full import paths + - Runtime discovery of available components + - Validation that a component exists + + Example: + >>> registry = ComponentRegistry("transforms") + >>> + >>> @registry.register() + ... class Normalize(Transform): + ... pass + >>> + >>> @registry.register("norm") # Custom alias + ... class Normalize(Transform): + ... pass + >>> + >>> # Retrieve by name + >>> Normalize = registry.get("Normalize") + >>> Normalize = registry.get("norm") + """ + + def __init__(self, name: str) -> None: + """ + Initialize the registry. + + Args: + name: Human-readable name for this registry (e.g., "transforms"). + """ + self.name = name + self._registry: dict[str, Type] = {} + + def register(self, name: str | None = None) -> Callable[[Type[T]], Type[T]]: + """ + Decorator to register a component class. + + Args: + name: Optional name to register under. If None, uses the class name. + + Returns: + Decorator function that registers the class. + + Example: + >>> @registry.register() + ... class MyTransform(Transform): + ... pass + >>> + >>> @registry.register("custom_name") + ... class AnotherTransform(Transform): + ... 
pass + """ + + def decorator(cls: Type[T]) -> Type[T]: + key = name if name is not None else cls.__name__ + if key in self._registry: + raise ValueError( + f"Component '{key}' is already registered in {self.name} registry. " + f"Existing: {self._registry[key]}, New: {cls}" + ) + self._registry[key] = cls + return cls + + return decorator + + def get(self, name: str) -> Type: + """ + Get a registered component by name. + + Args: + name: The registered name of the component. + + Returns: + The registered class. + + Raises: + KeyError: If the name is not registered. + """ + if name not in self._registry: + available = ", ".join(sorted(self._registry.keys())) + raise KeyError( + f"Component '{name}' not found in {self.name} registry. " + f"Available: {available or '(none)'}" + ) + return self._registry[name] + + def list(self) -> list[str]: + """ + List all registered component names. + + Returns: + Sorted list of registered names. + """ + return sorted(self._registry.keys()) + + def __contains__(self, name: str) -> bool: + """Check if a name is registered.""" + return name in self._registry + + def __len__(self) -> int: + """Return the number of registered components.""" + return len(self._registry) + + def __repr__(self) -> str: + return f"ComponentRegistry({self.name!r}, count={len(self)})" + + +# Global registries for transforms and readers +TRANSFORM_REGISTRY = ComponentRegistry("transforms") +READER_REGISTRY = ComponentRegistry("readers") + + +def register_transform(name: str | None = None) -> Callable[[Type[T]], Type[T]]: + """ + Decorator to register a transform class. + + This is a convenience wrapper around TRANSFORM_REGISTRY.register(). + + Args: + name: Optional name to register under. If None, uses the class name. + + Example: + >>> from physicsnemo.datapipes.core.registry import register_transform + >>> + >>> @register_transform() + ... class MyTransform(Transform): + ... pass + """ + return TRANSFORM_REGISTRY.register(name) + + +def register_reader(name: str | None = None) -> Callable[[Type[T]], Type[T]]: + """ + Decorator to register a reader class. + + This is a convenience wrapper around READER_REGISTRY.register(). + + Args: + name: Optional name to register under. If None, uses the class name. + + Example: + >>> from physicsnemo.datapipes.core.registry import register_reader + >>> + >>> @register_reader() + ... class MyReader(Reader): + ... pass + """ + return READER_REGISTRY.register(name) diff --git a/physicsnemo/datapipes/core/transforms/__init__.py b/physicsnemo/datapipes/core/transforms/__init__.py new file mode 100644 index 0000000000..34e0509a2d --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/__init__.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Transforms module - Operations on Samples. 
+ +Transforms are composable operations that take a Sample and return a modified Sample. +They are designed for GPU preprocessing +""" + +from physicsnemo.datapipes.core.transforms.base import Transform +from physicsnemo.datapipes.core.transforms.compose import Compose + +# NOTE: Downsample and ToDevice transforms are not yet implemented +# from physicsnemo.datapipes.core.transforms.downsample import Downsample +# from physicsnemo.datapipes.core.transforms.to_device import ToDevice, cpu, cuda +from physicsnemo.datapipes.core.transforms.field_processing import ( + BroadcastGlobalFeatures, +) +from physicsnemo.datapipes.core.transforms.field_slice import FieldSlice +from physicsnemo.datapipes.core.transforms.geometric import ( + ComputeNormals, + ComputeSDF, + ReScale, + Translate, +) +from physicsnemo.datapipes.core.transforms.normalize import Normalize +from physicsnemo.datapipes.core.transforms.spatial import ( + BoundingBoxFilter, + CenterOfMass, + CreateGrid, + KNNNeighbors, +) +from physicsnemo.datapipes.core.transforms.subsample import ( + SubsamplePoints, + poisson_sample_indices_fixed, + shuffle_array, +) + +__all__ = [ + # Base + "Transform", + "Compose", + # Existing transforms + "Normalize", + # "Downsample", # Not yet implemented + # "ToDevice", # Not yet implemented + # "cuda", # Not yet implemented + # "cpu", # Not yet implemented + # Subsampling + "SubsamplePoints", + "poisson_sample_indices_fixed", + "shuffle_array", + # Geometric + "ComputeSDF", + "ComputeNormals", + "Translate", + "ReScale", + # Field processing + "FieldSlice", + "BroadcastGlobalFeatures", + # Spatial + "BoundingBoxFilter", + "CreateGrid", + "KNNNeighbors", + "CenterOfMass", +] diff --git a/physicsnemo/datapipes/core/transforms/base.py b/physicsnemo/datapipes/core/transforms/base.py new file mode 100644 index 0000000000..58d4bb9009 --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/base.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Transform base class - The foundation for all data transformations. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +from tensordict import TensorDict + + +class Transform(ABC): + """ + Abstract base class for all transforms. + + Transforms operate on a TensorDict and return a modified TensorDict. + They are designed to run on GPU tensors for maximum performance. + Metadata is not passed to transforms (handled separately by Dataset/DataLoader). + + Subclasses must implement: + - __call__(data: TensorDict) -> TensorDict + + Optionally override: + - extra_repr() -> str: For custom repr output + - state_dict() -> dict: For serialization + - load_state_dict(state_dict: dict): For deserialization + + Example: + >>> class MyTransform(Transform): + ... def __init__(self, scale: float): + ... 
super().__init__() + ... self.scale = scale + ... + ... def __call__(self, data: TensorDict) -> TensorDict: + ... # Apply transformation to all tensors + ... return data.apply(lambda x: x * self.scale) + """ + + def __init__(self) -> None: + """Initialize the transform.""" + self._device: Optional[torch.device] = None + + @abstractmethod + def __call__(self, data: TensorDict) -> TensorDict: + """ + Apply the transform to a TensorDict. + + Args: + data: Input TensorDict to transform. + + Returns: + Transformed TensorDict. + """ + raise NotImplementedError + + def to(self, device: torch.device | str) -> Transform: + """ + Move any internal tensors to the specified device. + + Override this method if your transform has internal tensor state. + + Args: + device: Target device. + + Returns: + Self for chaining. + """ + self._device = torch.device(device) if isinstance(device, str) else device + return self + + @property + def device(self) -> torch.device | None: + """The device this transform operates on.""" + return self._device + + def extra_repr(self) -> str: + """ + Return extra information for repr. + + Override this to add transform-specific info to the repr. + """ + return "" + + def state_dict(self) -> dict[str, Any]: + """ + Return a dictionary containing the transform's state. + + Override this for transforms with learnable or configurable state. + """ + return {} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Load state from a state dictionary. + + Override this to restore transform state. + """ + pass + + def __repr__(self) -> str: + extra = self.extra_repr() + if extra: + return f"{self.__class__.__name__}({extra})" + return f"{self.__class__.__name__}()" diff --git a/physicsnemo/datapipes/core/transforms/compose.py b/physicsnemo/datapipes/core/transforms/compose.py new file mode 100644 index 0000000000..8c9e3570c9 --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/compose.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compose - Chain multiple transforms into a single transform. +""" + +from __future__ import annotations + +from typing import Any, Iterator, Sequence + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.registry import register_transform +from physicsnemo.datapipes.core.transforms.base import Transform + + +@register_transform() +class Compose(Transform): + """ + Compose multiple transforms into a sequential pipeline. + + Applies transforms in order, passing the output of each as input to the next. + + Example: + >>> pipeline = Compose([ + ... Normalize(["pressure"], means={"pressure": 0.0}, stds={"pressure": 1.0}), + ... Downsample(["pressure"], n=1000), + ... ]) + >>> transformed = pipeline(sample) + """ + + def __init__(self, transforms: Sequence[Transform]) -> None: + """ + Initialize the composition. 
+ + Args: + transforms: Sequence of transforms to apply in order. + + Raises: + TypeError: If any element is not a Transform. + ValueError: If transforms is empty. + """ + super().__init__() + + if not transforms: + raise ValueError("transforms cannot be empty") + + for i, t in enumerate(transforms): + if not isinstance(t, Transform): + raise TypeError( + f"All elements must be Transform instances, " + f"got {type(t).__name__} at index {i}" + ) + + self.transforms: list[Transform] = list(transforms) + + def __call__(self, data: TensorDict) -> TensorDict: + """Apply all transforms in sequence.""" + for transform in self.transforms: + data = transform(data) + return data + + def to(self, device: torch.device | str) -> Compose: + """Move all transforms to the specified device.""" + super().to(device) + for transform in self.transforms: + transform.to(device) + return self + + def __getitem__(self, index: int) -> Transform: + """Get a transform by index.""" + return self.transforms[index] + + def __len__(self) -> int: + """Return number of transforms.""" + return len(self.transforms) + + def __iter__(self) -> Iterator[Transform]: + """Iterate over transforms.""" + return iter(self.transforms) + + def append(self, transform: Transform) -> None: + """Append a transform to the pipeline.""" + if not isinstance(transform, Transform): + raise TypeError(f"Expected Transform, got {type(transform).__name__}") + self.transforms.append(transform) + + def state_dict(self) -> dict[str, Any]: + """Return state of all transforms.""" + return { + "transforms": [t.state_dict() for t in self.transforms], + "transform_types": [type(t).__name__ for t in self.transforms], + } + + def extra_repr(self) -> str: + lines = [] + for i, t in enumerate(self.transforms): + lines.append(f" ({i}): {t}") + return "\n" + "\n".join(lines) + "\n" diff --git a/physicsnemo/datapipes/core/transforms/field_processing.py b/physicsnemo/datapipes/core/transforms/field_processing.py new file mode 100644 index 0000000000..9644072027 --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/field_processing.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Field processing transforms for feature engineering. + +Provides transforms for broadcasting global features to local points. +""" + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.registry import register_transform +from physicsnemo.datapipes.core.transforms.base import Transform + + +@register_transform() +class BroadcastGlobalFeatures(Transform): + r""" + Broadcast global scalar/vector features to all spatial points. + + Replicates global parameters (e.g., density, velocity) to match the number + of spatial points, enabling concatenation with local features. 
+ + Parameters + ---------- + input_keys : list[str] + List of global feature keys to broadcast. + n_points_key : str + Key of a tensor whose first dimension gives the number of points to broadcast to. + output_key : str + Key to store the broadcasted features. + + Examples + -------- + >>> transform = BroadcastGlobalFeatures( + ... input_keys=["air_density", "stream_velocity"], + ... n_points_key="embeddings", + ... output_key="fx" + ... ) + >>> data = TensorDict({ + ... "air_density": torch.tensor(1.225), + ... "stream_velocity": torch.tensor(30.0), + ... "embeddings": torch.randn(10000, 7) + ... }) + >>> result = transform(data) + >>> print(result["fx"].shape) + torch.Size([10000, 2]) + """ + + def __init__( + self, + input_keys: list[str], + n_points_key: str, + output_key: str, + ) -> None: + """Initialize the broadcast transform.""" + super().__init__() + self.input_keys = input_keys + self.n_points_key = n_points_key + self.output_key = output_key + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Broadcast global features to match spatial dimensions. + + Parameters + ---------- + data : TensorDict + Input TensorDict containing global features and reference tensor. + + Returns + ------- + TensorDict + TensorDict with broadcasted features added. + + Raises + ------ + KeyError + If required keys are not found in the TensorDict. + """ + if self.n_points_key not in data.keys(): + raise KeyError(f"Reference key '{self.n_points_key}' not found") + + n_points = data[self.n_points_key].shape[0] + + # Collect features + features = [] + for key in self.input_keys: + if key not in data.keys(): + raise KeyError(f"Feature key '{key}' not found") + + feature = data[key] + + # Ensure scalar features are expanded + if feature.ndim == 0: + feature = feature.unsqueeze(0) + + features.append(feature) + + # Stack features + fx = torch.stack(features, dim=-1) + + # Broadcast to match number of points + fx = fx.broadcast_to(n_points, fx.shape[-1]) + + return data.update({self.output_key: fx}) + + def __repr__(self) -> str: + return ( + f"BroadcastGlobalFeatures(input_keys={self.input_keys}, " + f"output_key={self.output_key})" + ) diff --git a/physicsnemo/datapipes/core/transforms/field_slice.py b/physicsnemo/datapipes/core/transforms/field_slice.py new file mode 100644 index 0000000000..64cd7a0479 --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/field_slice.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FieldSlice - Select specific indices or slices from tensor dimensions. 
+""" + +from __future__ import annotations + +from typing import Union + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.registry import register_transform +from physicsnemo.datapipes.core.transforms.base import Transform + +# Type for a single dimension's slice specification +# Can be: list of indices [0, 2, 5], or dict for slice {"start": 0, "stop": 5, "step": 2} +SliceSpec = Union[list[int], dict[str, int]] + + +@register_transform() +class FieldSlice(Transform): + """ + Select specific indices or slices from tensor dimensions. + + This transform allows selecting subsets of data along any dimension of + specified fields. It supports two modes: + + 1. **Index selection**: Provide a list of indices to select + 2. **Slice selection**: Provide start/stop/step as a dict + + Example with index selection: + >>> # Select features 0, 2, 5 from last dimension of "features" field + >>> transform = FieldSlice({ + ... "features": {-1: [0, 2, 5]}, + ... }) + >>> # Input shape: (N, 10) -> Output shape: (N, 3) + + Example with slice selection: + >>> # Select first 5 features using slice notation + >>> transform = FieldSlice({ + ... "features": {-1: {"start": 0, "stop": 5}}, + ... }) + >>> # Input shape: (N, 10) -> Output shape: (N, 5) + + Example with multiple dimensions: + >>> # Slice both dimensions + >>> transform = FieldSlice({ + ... "grid": { + ... 0: [0, 1, 2], # First 3 indices of dim 0 + ... -1: {"stop": 4}, # First 4 of last dim (slice) + ... }, + ... }) + + Hydra configuration example: + .. code-block:: yaml + + _target_: physicsnemo.datapipes.core.transforms.FieldSlice + slicing: + features: + "-1": [0, 2, 5] + velocity: + "-1": + stop: 2 + """ + + def __init__( + self, + slicing: dict[str, dict[int | str, SliceSpec]], + ) -> None: + """ + Initialize the FieldSlice transform. + + Args: + slicing: Dictionary mapping field names to dimension slicing specs. + Format:: + + { + "field_name": { + dim: indices_or_slice, + ... + }, + ... + } + + Where: + - ``dim`` is the dimension index (int, or str for Hydra like "-1") + - ``indices_or_slice`` is either: + - A list of indices: ``[0, 2, 5]`` + - A slice dict: ``{"start": 0, "stop": 5, "step": 1}`` + """ + super().__init__() + self.slicing = slicing + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Apply slicing to the specified fields. + + Args: + data: Input TensorDict. + + Returns: + TensorDict with sliced fields. + + Raises: + KeyError: If a specified field is not in the TensorDict. + """ + updates = {} + + for field_name, dim_specs in self.slicing.items(): + if field_name not in data.keys(): + raise KeyError( + f"Field '{field_name}' not found in data. " + f"Available: {list(data.keys())}" + ) + + tensor = data[field_name] + + for dim_key, spec in dim_specs.items(): + # Handle string keys from Hydra/YAML (e.g., "-1" -> -1) + dim = int(dim_key) + # Normalize negative dimension + if dim < 0: + dim = tensor.ndim + dim + + tensor = self._apply_slice(tensor, dim, spec) + + updates[field_name] = tensor + + return data.update(updates) + + def _apply_slice( + self, + tensor: torch.Tensor, + dim: int, + spec: SliceSpec, + ) -> torch.Tensor: + """ + Apply a single slice specification to a tensor. + + Args: + tensor: Input tensor. + dim: Dimension to slice (normalized to positive). + spec: Slice specification (list of indices or slice dict). + + Returns: + Sliced tensor. 
+        """
+        if isinstance(spec, list):
+            # Index selection: [0, 2, 5]
+            indices = torch.tensor(spec, dtype=torch.long, device=tensor.device)
+            return torch.index_select(tensor, dim, indices)
+        elif isinstance(spec, dict):
+            # Slice selection: {"start": 0, "stop": 5, "step": 1}
+            start = spec.get("start", None)
+            stop = spec.get("stop", None)
+            step = spec.get("step", None)
+
+            # Build slice object
+            slc = slice(start, stop, step)
+
+            # Apply slice using narrow or direct indexing
+            # We need to build the full index tuple
+            idx = [slice(None)] * tensor.ndim
+            idx[dim] = slc
+            return tensor[tuple(idx)]
+        else:
+            raise TypeError(
+                f"Invalid slice spec type: {type(spec)}. "
+                f"Expected list of indices or dict with start/stop/step."
+            )
+
+    def extra_repr(self) -> str:
+        return f"slicing={self.slicing}"
diff --git a/physicsnemo/datapipes/core/transforms/geometric.py b/physicsnemo/datapipes/core/transforms/geometric.py
new file mode 100644
index 0000000000..c8ad612433
--- /dev/null
+++ b/physicsnemo/datapipes/core/transforms/geometric.py
@@ -0,0 +1,358 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Geometric transforms for spatial data processing.
+
+Provides transforms for computing signed distance fields, normals,
+and applying spatial invariances (translation, scaling).
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Union
+
+import torch
+from tensordict import TensorDict
+
+from physicsnemo.datapipes.core.registry import register_transform
+from physicsnemo.datapipes.core.transforms.base import Transform
+from physicsnemo.nn.sdf import signed_distance_field
+
+
+@register_transform()
+class ComputeSDF(Transform):
+    r"""
+    Compute signed distance field from a mesh.
+
+    Computes the signed distance from query points to the nearest point on
+    a triangular mesh surface. Optionally returns the closest points on the
+    mesh surface for each query point.
+
+    Example:
+        >>> transform = ComputeSDF(
+        ...     input_keys=["volume_mesh_centers"],
+        ...     output_key="sdf_nodes",
+        ...     mesh_coords_key="stl_coordinates",
+        ...     mesh_faces_key="stl_faces",
+        ...     closest_points_key="closest_points"
+        ... )
+        >>> sample = TensorDict({
+        ...     "volume_mesh_centers": torch.randn(10000, 3),
+        ...     "stl_coordinates": torch.randn(5000, 3),
+        ...     "stl_faces": torch.randint(0, 5000, (10000,))
+        ... })
+        >>> result = transform(sample)
+        >>> print(result["sdf_nodes"].shape)
+        torch.Size([10000, 1])
+    """
+
+    def __init__(
+        self,
+        input_keys: list[str],
+        output_key: str,
+        mesh_coords_key: str,
+        mesh_faces_key: str,
+        *,
+        use_winding_number: bool = True,
+        closest_points_key: Optional[str] = None,
+    ) -> None:
+        """
+        Initialize the SDF computation transform.
+
+        Args:
+            input_keys: List of keys containing query points to compute SDF for.
+                Each tensor should have shape :math:`(N, 3)`.
+ output_key: Key to store the computed SDF values. + mesh_coords_key: Key for mesh vertex coordinates, shape :math:`(M, 3)`. + mesh_faces_key: Key for mesh face indices (flattened), shape :math:`(F*3,)`. + use_winding_number: If True, use winding number for sign determination. + closest_points_key: Optional key to store closest points on mesh. + """ + super().__init__() + self.input_keys = input_keys + self.output_key = output_key + self.mesh_coords_key = mesh_coords_key + self.mesh_faces_key = mesh_faces_key + self.use_winding_number = use_winding_number + self.closest_points_key = closest_points_key + + def __call__(self, data: TensorDict) -> TensorDict: + """Compute SDF for the sample.""" + # Get mesh data + if self.mesh_coords_key not in data: + raise KeyError(f"Mesh coordinates key '{self.mesh_coords_key}' not found") + if self.mesh_faces_key not in data: + raise KeyError(f"Mesh faces key '{self.mesh_faces_key}' not found") + + mesh_coords = data[self.mesh_coords_key] + mesh_faces = data[self.mesh_faces_key].to(torch.int32) + + updates = {} + + # Compute SDF for each input key + for key in self.input_keys: + if key not in data: + raise KeyError(f"Input key '{key}' not found") + + query_points = data[key] + + # Compute SDF and closest points + sdf, closest_points = signed_distance_field( + mesh_coords, + mesh_faces, + query_points, + use_sign_winding_number=self.use_winding_number, + ) + + # Store SDF with output key (add suffix if multiple inputs) + if len(self.input_keys) == 1: + updates[self.output_key] = sdf.reshape(-1, 1) + if self.closest_points_key is not None: + updates[self.closest_points_key] = closest_points + else: + suffix = f"_{key}" + updates[f"{self.output_key}{suffix}"] = sdf.reshape(-1, 1) + if self.closest_points_key is not None: + updates[f"{self.closest_points_key}{suffix}"] = closest_points + + return data.update(updates) + + def __repr__(self) -> str: + return f"ComputeSDF(input_keys={self.input_keys}, output_key={self.output_key})" + + +@register_transform() +class ComputeNormals(Transform): + r""" + Compute normal vectors from closest points. + + Computes normalized direction vectors from query points to their closest + points on a surface. Handles zero-distance edge cases by falling back to + center of mass direction. + + Example: + >>> transform = ComputeNormals( + ... positions_key="volume_mesh_centers", + ... closest_points_key="closest_points", + ... center_of_mass_key="center_of_mass", + ... output_key="volume_normals" + ... ) + """ + + def __init__( + self, + positions_key: str, + closest_points_key: str, + center_of_mass_key: str, + output_key: str, + *, + handle_zero_distance: bool = True, + ) -> None: + """ + Initialize the normal computation transform. + + Args: + positions_key: Key for position tensor, shape :math:`(N, 3)`. + closest_points_key: Key for closest points tensor, shape :math:`(N, 3)`. + center_of_mass_key: Key for center of mass, shape :math:`(1, 3)` or :math:`(3,)`. + output_key: Key to store computed normals. + handle_zero_distance: If True, use center_of_mass fallback for zero distances. 
+        """
+        super().__init__()
+        self.positions_key = positions_key
+        self.closest_points_key = closest_points_key
+        self.center_of_mass_key = center_of_mass_key
+        self.output_key = output_key
+        self.handle_zero_distance = handle_zero_distance
+
+    def __call__(self, data: TensorDict) -> TensorDict:
+        """Compute normals for the sample."""
+        positions = data[self.positions_key]
+        closest_points = data[self.closest_points_key]
+        center_of_mass = data[self.center_of_mass_key]
+
+        # Ensure center_of_mass has shape (1, 3)
+        if center_of_mass.ndim == 1:
+            center_of_mass = center_of_mass.unsqueeze(0)
+
+        # Compute initial normals
+        normals = positions - closest_points
+
+        if self.handle_zero_distance:
+            # Handle zero-distance points (on or very close to surface)
+            distance_to_closest = torch.norm(normals, dim=-1)
+            null_points = distance_to_closest < 1e-6
+
+            # For null points, use direction from center of mass
+            if null_points.any():
+                normals[null_points] = positions[null_points] - center_of_mass
+
+        # Normalize
+        norm = torch.norm(normals, dim=-1, keepdim=True) + 1e-6
+        normals = normals / norm
+
+        return data.update({self.output_key: normals})
+
+    def __repr__(self) -> str:
+        return (
+            f"ComputeNormals(positions_key={self.positions_key}, "
+            f"output_key={self.output_key})"
+        )
+
+
+class Translate(Transform):
+    r"""
+    Apply a translation by subtracting a center point.
+
+    Subtracts a reference point (typically center of mass) from position-like
+    tensors to make the representation translation invariant.
+
+    Example:
+        >>> transform = Translate(
+        ...     input_keys=["volume_mesh_centers", "surface_mesh_centers"],
+        ...     center_key_or_value="center_of_mass"
+        ... )
+    """
+
+    def __init__(
+        self,
+        input_keys: list[str],
+        center_key_or_value: Union[str, torch.Tensor],
+    ) -> None:
+        """
+        Initialize the translation transform.
+
+        Args:
+            input_keys: List of position tensor keys to translate.
+            center_key_or_value: Either a key name (str) for a tensor in the sample,
+                or a fixed tensor value to subtract.
+        """
+        super().__init__()
+        self.input_keys = input_keys
+        self.center_key_or_value = center_key_or_value
+        self.is_key = isinstance(center_key_or_value, str)
+
+    def __call__(self, data: TensorDict) -> TensorDict:
+        """Apply translation to the sample."""
+        # Get center value
+        if isinstance(self.center_key_or_value, str):
+            if self.center_key_or_value not in data:
+                raise KeyError(f"Center key '{self.center_key_or_value}' not found")
+            center = data[self.center_key_or_value]
+        else:
+            if not isinstance(self.center_key_or_value, torch.Tensor):
+                raise TypeError(
+                    f"center_key_or_value should be torch.Tensor but got {type(self.center_key_or_value)}"
+                )
+            center = self.center_key_or_value
+            # Move to same device as data if needed
+            if data.device is not None and center.device != data.device:
+                center = center.to(data.device)
+
+        # Ensure center has shape (1, 3) or (1, D)
+        if center.ndim == 1:
+            center = center.unsqueeze(0)
+
+        # Apply translation to all keys
+        updates = {}
+        for key in self.input_keys:
+            if key in data:
+                updates[key] = data[key] - center
+
+        return data.update(updates)
+
+    def to(self, device: Union[torch.device, str]) -> "Translate":
+        """Move center tensor to the specified device (if not a key reference)."""
+        super().to(device)
+        if not self.is_key:
+            if not isinstance(self.center_key_or_value, torch.Tensor):
+                raise TypeError(
+                    f"center_key_or_value should be torch.Tensor but got {type(self.center_key_or_value)}"
+                )
+            device = torch.device(device) if isinstance(device, str) else device
+            self.center_key_or_value = self.center_key_or_value.to(device)
+        return self
+
+    def __repr__(self) -> str:
+        return (
+            f"Translate(input_keys={self.input_keys}, "
+            f"center={self.center_key_or_value})"
+        )
+
+
+class ReScale(Transform):
+    r"""
+    Apply a scale factor by dividing by a reference scale.
+
+    Divides position tensors by a reference scale to make the representation
+    scale invariant.
+
+    Example:
+        >>> transform = ReScale(
+        ...     input_keys=["volume_mesh_centers", "geometry_coordinates"],
+        ...     reference_scale=torch.tensor([[1.0, 1.0, 1.0]])
+        ... )
+    """
+
+    def __init__(
+        self,
+        input_keys: list[str],
+        reference_scale: torch.Tensor,
+    ) -> None:
+        """
+        Initialize the scale invariance transform.
+
+        Args:
+            input_keys: List of position tensor keys to scale.
+            reference_scale: Tensor to divide by, shape :math:`(1, D)` or :math:`(D,)`.
+        """
+        super().__init__()
+        self.input_keys = input_keys
+        self.reference_scale = reference_scale
+
+    def __call__(self, data: TensorDict) -> TensorDict:
+        """Apply scaling to the data."""
+        scale = self.reference_scale
+
+        # Ensure scale has batch dimension
+        if scale.ndim == 1:
+            scale = scale.unsqueeze(0)
+
+        # Move scale to same device as data if needed
+        if data.device is not None and scale.device != data.device:
+            scale = scale.to(data.device)
+
+        # Apply scaling to all keys
+        updates = {}
+        for key in self.input_keys:
+            if key in data:
+                updates[key] = data[key] / scale
+
+        return data.update(updates)
+
+    def to(self, device: Union[torch.device, str]) -> "ReScale":
+        """Move reference scale tensor to the specified device."""
+        super().to(device)
+        device = torch.device(device) if isinstance(device, str) else device
+        self.reference_scale = self.reference_scale.to(device)
+        return self
+
+    def __repr__(self) -> str:
+        return (
+            f"ReScale(input_keys={self.input_keys}, "
+            f"scale_shape={self.reference_scale.shape})"
+        )
diff --git a/physicsnemo/datapipes/core/transforms/normalize.py b/physicsnemo/datapipes/core/transforms/normalize.py
new file mode 100644
index 0000000000..062acc47ae
--- /dev/null
+++ b/physicsnemo/datapipes/core/transforms/normalize.py
@@ -0,0 +1,390 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Normalize - Standardize tensor values by mean and standard deviation or min-max scaling.
+"""
+
+from __future__ import annotations
+
+import warnings
+from pathlib import Path
+from typing import Any, Literal, Optional
+
+import numpy as np
+import torch
+from tensordict import TensorDict
+
+from physicsnemo.datapipes.core.registry import register_transform
+from physicsnemo.datapipes.core.transforms.base import Transform
+
+
+@register_transform()
+class Normalize(Transform):
+    """
+    Normalize specified fields using mean-std or min-max scaling.
+
+    Supports two normalization methods:
+    - ``mean_std``: Applies (x - mean) / std for each specified field
+    - ``min_max``: Applies (x - center) / half_range, normalizing to [-1, 1]
+      where center = (max + min) / 2 and half_range = (max - min) / 2
+
+    Parameters can be provided directly or loaded from a ``.npz`` file.
+
+    Example (mean-std scaling):
+        >>> norm = Normalize(
+        ...     input_keys=["pressure", "velocity"],
+        ...     method="mean_std",
+        ...     means={"pressure": 101325.0, "velocity": 0.0},
+        ...     stds={"pressure": 1000.0, "velocity": 10.0},
+        ... )
+        >>> normalized = norm(sample)
+
+    Example (min-max scaling):
+        >>> norm = Normalize(
+        ...     input_keys=["pressure", "velocity"],
+        ...     method="min_max",
+        ...     mins={"pressure": 100000.0, "velocity": -50.0},
+        ...     maxs={"pressure": 110000.0, "velocity": 50.0},
+        ... )
+        >>> normalized = norm(sample)
+
+    Example (loading from file):
+        >>> norm = Normalize(
+        ...     
input_keys=["pressure", "velocity"], + ... method="mean_std", + ... stats_file="normalization_stats.npz", + ... ) + >>> normalized = norm(sample) + """ + + def __init__( + self, + input_keys: list[str], + means: Optional[dict[str, float | torch.Tensor] | float | torch.Tensor] = None, + stds: Optional[dict[str, float | torch.Tensor] | float | torch.Tensor] = None, + *, + method: Optional[Literal["mean_std", "min_max"]] = None, + mins: Optional[dict[str, float | torch.Tensor] | float | torch.Tensor] = None, + maxs: Optional[dict[str, float | torch.Tensor] | float | torch.Tensor] = None, + stats_file: Optional[str | Path] = None, + eps: float = 1e-8, + ) -> None: + """ + Initialize the normalizer. + + Args: + input_keys: List of field names to normalize. + means: Mean values for mean_std method. Either a dict mapping field names + to values, or a single value applied to all fields. Deprecated if + ``method`` is not specified. + stds: Standard deviation values for mean_std method. Same format as means. + method: Normalization method - either ``"mean_std"`` or ``"min_max"``. + mins: Minimum values for min_max method. Same format as means. + maxs: Maximum values for min_max method. Same format as means. + stats_file: Path to ``.npz`` file containing normalization statistics. + File should contain per-field dicts with keys like 'mean', + 'std', 'min', 'max'. + eps: Small value added to prevent division by zero. + + Raises: + ValueError: If input_keys is empty, method is invalid, or required + parameters are missing. + """ + super().__init__() + + if not input_keys: + raise ValueError("input_keys cannot be empty") + + self.input_keys = list(input_keys) + self.eps = eps + + # Handle backward compatibility: if means/stds provided without method + if means is not None and stds is not None and method is None: + warnings.warn( + "Providing 'means' and 'stds' without 'method' parameter is deprecated. " + "Please specify method='mean_std' explicitly. 
" + "This will become an error in a future version.", + DeprecationWarning, + stacklevel=2, + ) + method = "mean_std" + + # Validate method + if method not in ["mean_std", "min_max"]: + raise ValueError(f"method must be 'mean_std' or 'min_max', got: {method}") + + self.method = method + + # Load stats from file if provided + if stats_file is not None: + stats = self._load_stats_from_npz(stats_file) + if method == "mean_std": + if means is None: + means = stats.get("means", {}) + if stds is None: + stds = stats.get("stds", {}) + else: # min_max + if mins is None: + mins = stats.get("mins", {}) + if maxs is None: + maxs = stats.get("maxs", {}) + + # Initialize storage based on method + if method == "mean_std": + if means is None or stds is None: + raise ValueError( + "For method='mean_std', both 'means' and 'stds' must be provided " + "either directly or via stats_file" + ) + self._means = self._process_stats_dict(means, "mean") + self._stds = self._process_stats_dict(stds, "std") + self._mins: Optional[dict[str, torch.Tensor]] = None + self._maxs: Optional[dict[str, torch.Tensor]] = None + + else: # min_max + if mins is None or maxs is None: + raise ValueError( + "For method='min_max', both 'mins' and 'maxs' must be provided " + "either directly or via stats_file" + ) + self._mins = self._process_stats_dict(mins, "min") + self._maxs = self._process_stats_dict(maxs, "max") + self._means: Optional[dict[str, torch.Tensor]] = None + self._stds: Optional[dict[str, torch.Tensor]] = None + + def _process_stats_dict( + self, + stats: dict[str, float | torch.Tensor] | float | torch.Tensor, + stat_name: str, + ) -> dict[str, torch.Tensor]: + """Process statistics into dict of tensors for each field.""" + result: dict[str, torch.Tensor] = {} + + if isinstance(stats, dict): + for key in self.input_keys: + if key not in stats: + raise ValueError( + f"{stat_name.capitalize()} not provided for field '{key}'" + ) + val = stats[key] + result[key] = ( + torch.as_tensor(val) if not isinstance(val, torch.Tensor) else val + ) + else: + # Single value for all fields + stat_tensor = ( + torch.as_tensor(stats) if not isinstance(stats, torch.Tensor) else stats + ) + for key in self.input_keys: + result[key] = stat_tensor.clone() + + return result + + def _load_stats_from_npz(self, stats_file: str | Path) -> dict[str, dict]: + """ + Load normalization statistics from .npz file. + + Expected file structure: Dictionary mapping field names to dicts with + keys 'mean', 'std', 'min', 'max' (numpy arrays). + + Args: + stats_file: Path to .npz file. + + Returns: + Dictionary with keys 'means', 'stds', 'mins', 'maxs', each mapping + field names to torch tensors. + + Raises: + FileNotFoundError: If file doesn't exist. + ValueError: If required statistics are missing. + """ + file_path = Path(stats_file) + if not file_path.exists(): + raise FileNotFoundError(f"Statistics file not found: {stats_file}") + + # Load npz file + data = np.load(str(file_path), allow_pickle=True) + + # Initialize output dicts + means_dict = {} + stds_dict = {} + mins_dict = {} + maxs_dict = {} + + # Process each field + for key in self.input_keys: + if key not in data: + raise ValueError( + f"Field '{key}' not found in stats file. 
" + f"Available fields: {list(data.keys())}" + ) + + field_stats = data[key] + if isinstance(field_stats, np.ndarray) and field_stats.dtype == object: + # It's a dict stored as numpy object + field_stats = field_stats.item() + + # Extract stats if available + if "mean" in field_stats: + means_dict[key] = torch.as_tensor(field_stats["mean"]) + if "std" in field_stats: + stds_dict[key] = torch.as_tensor(field_stats["std"]) + if "min" in field_stats: + mins_dict[key] = torch.as_tensor(field_stats["min"]) + if "max" in field_stats: + maxs_dict[key] = torch.as_tensor(field_stats["max"]) + + return { + "means": means_dict, + "stds": stds_dict, + "mins": mins_dict, + "maxs": maxs_dict, + } + + def __call__(self, data: TensorDict) -> TensorDict: + """ + Normalize the specified fields in the TensorDict. + + Args: + data: Input TensorDict. + + Returns: + TensorDict with normalized fields. + + Raises: + KeyError: If a specified field is not in the TensorDict. + """ + updates = {} + + for key in self.input_keys: + if key not in data.keys(): + raise KeyError( + f"Field '{key}' not found in data. Available: {list(data.keys())}" + ) + + tensor = data[key] + + if self.method == "mean_std": + mean = self._means[key] + std = self._stds[key] + + # Normalize: (x - mean) / std + updates[key] = (tensor - mean) / (std + self.eps) + + else: # min_max + min_val = self._mins[key] + max_val = self._maxs[key] + + # Normalize to [-1, 1]: (x - center) / half_range + center = (max_val + min_val) / 2.0 + half_range = (max_val - min_val) / 2.0 + updates[key] = (tensor - center) / (half_range + self.eps) + + # Update TensorDict with normalized values + return data.update(updates) + + def to(self, device: torch.device | str) -> Normalize: + """Move normalization parameters to the specified device.""" + super().to(device) + device = torch.device(device) if isinstance(device, str) else device + + if self.method == "mean_std": + for key in self.input_keys: + self._means[key] = self._means[key].to(device, non_blocking=True) + self._stds[key] = self._stds[key].to(device, non_blocking=True) + else: # min_max + for key in self.input_keys: + self._mins[key] = self._mins[key].to(device, non_blocking=True) + self._maxs[key] = self._maxs[key].to(device, non_blocking=True) + + return self + + def inverse(self, data: TensorDict) -> TensorDict: + """ + Apply inverse normalization (denormalize). + + For mean_std method: x * std + mean + For min_max method: x * half_range + center + + Args: + data: Normalized TensorDict. + + Returns: + Denormalized TensorDict. 
+ """ + updates = {} + + for key in self.input_keys: + if key not in data.keys(): + raise KeyError(f"Field '{key}' not found in data") + + tensor = data[key] + + if self.method == "mean_std": + mean = self._means[key] + std = self._stds[key] + + updates[key] = tensor * (std + self.eps) + mean + + else: # min_max + min_val = self._mins[key] + max_val = self._maxs[key] + + center = (max_val + min_val) / 2.0 + half_range = (max_val - min_val) / 2.0 + updates[key] = tensor * (half_range + self.eps) + center + + return data.update(updates) + + def state_dict(self) -> dict[str, Any]: + """Return normalization parameters.""" + state = { + "input_keys": self.input_keys, + "method": self.method, + "eps": self.eps, + } + + if self.method == "mean_std": + state["means"] = {k: v.cpu() for k, v in self._means.items()} + state["stds"] = {k: v.cpu() for k, v in self._stds.items()} + else: # min_max + state["mins"] = {k: v.cpu() for k, v in self._mins.items()} + state["maxs"] = {k: v.cpu() for k, v in self._maxs.items()} + + return state + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load normalization parameters.""" + self.input_keys = state_dict["input_keys"] + self.method = state_dict.get( + "method", "mean_std" + ) # Default for backward compat + self.eps = state_dict["eps"] + + if self.method == "mean_std": + self._means = {k: v.clone() for k, v in state_dict["means"].items()} + self._stds = {k: v.clone() for k, v in state_dict["stds"].items()} + self._mins = None + self._maxs = None + else: # min_max + self._mins = {k: v.clone() for k, v in state_dict["mins"].items()} + self._maxs = {k: v.clone() for k, v in state_dict["maxs"].items()} + self._means = None + self._stds = None + + def extra_repr(self) -> str: + return f"method={self.method}, input_keys={self.input_keys}, eps={self.eps}" diff --git a/physicsnemo/datapipes/core/transforms/spatial.py b/physicsnemo/datapipes/core/transforms/spatial.py new file mode 100644 index 0000000000..8742549402 --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/spatial.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spatial transforms for mesh and grid processing. + +Provides generic transforms for spatial operations including bounding box +filtering, grid creation, k-NN neighbor computation, and center of mass calculation. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.registry import register_transform +from physicsnemo.datapipes.core.transforms.base import Transform +from physicsnemo.nn.neighbors import knn + + +@register_transform() +class BoundingBoxFilter(Transform): + r""" + Filter points outside a spatial bounding box. 
+ + Removes points that fall outside specified min/max bounds and applies + the same filtering to dependent arrays to maintain correspondence. + This is useful for focusing on specific regions of interest or removing + outliers from simulation data. + + Example: + >>> transform = BoundingBoxFilter( + ... input_keys=["volume_mesh_centers"], + ... bbox_min=torch.tensor([-1.0, -1.0, -1.0]), + ... bbox_max=torch.tensor([1.0, 1.0, 1.0]), + ... dependent_keys=["volume_fields", "sdf_nodes"] + ... ) + >>> sample = TensorDict({ + ... "volume_mesh_centers": torch.randn(10000, 3) * 2, # Some outside bbox + ... "volume_fields": torch.randn(10000, 4) + ... }) + >>> result = transform(sample) + >>> # Only points within bbox remain + """ + + def __init__( + self, + input_keys: list[str], + bbox_min: torch.Tensor, + bbox_max: torch.Tensor, + *, + dependent_keys: Optional[list[str]] = None, + ) -> None: + """ + Initialize the bounding box filter transform. + + Args: + input_keys: List of coordinate tensor keys to filter. + bbox_min: Minimum corner of bounding box, shape :math:`(3,)`. + bbox_max: Maximum corner of bounding box, shape :math:`(3,)`. + dependent_keys: Optional list of keys to filter using the same mask. + These maintain correspondence with the filtered coordinates. + """ + super().__init__() + self.input_keys = input_keys + self.bbox_min = bbox_min + self.bbox_max = bbox_max + self.dependent_keys = dependent_keys or [] + + def __call__(self, data: TensorDict) -> TensorDict: + """Apply bounding box filtering to the sample.""" + updates = {} + + for coord_key in self.input_keys: + if coord_key not in data: + continue + + coords = data[coord_key] + + # Move bbox to same device + bbox_min = self.bbox_min.to(coords.device) + bbox_max = self.bbox_max.to(coords.device) + + # Create mask for points inside bbox + ids_min = coords > bbox_min + ids_max = coords < bbox_max + ids_in_bbox = ids_min & ids_max + ids_in_bbox = ids_in_bbox.all(dim=-1) + + # Apply mask to coordinates + updates[coord_key] = coords[ids_in_bbox] + + # Apply same mask to dependent keys + for dep_key in self.dependent_keys: + if dep_key in data: + updates[dep_key] = data[dep_key][ids_in_bbox] + + return data.update(updates) + + def to(self, device: Union[torch.device, str]) -> "BoundingBoxFilter": + """Move bounding box tensors to the specified device.""" + super().to(device) + device = torch.device(device) if isinstance(device, str) else device + self.bbox_min = self.bbox_min.to(device) + self.bbox_max = self.bbox_max.to(device) + return self + + def __repr__(self) -> str: + return ( + f"BoundingBoxFilter(input_keys={self.input_keys}, " + f"dependent_keys={self.dependent_keys})" + ) + + +@register_transform() +class CreateGrid(Transform): + r""" + Create a regular 3D spatial grid. + + Generates a uniform grid spanning a bounding box, used for latent space + representations, interpolation grids, or structured spatial queries. + + Example: + >>> transform = CreateGrid( + ... output_key="grid", + ... resolution=(64, 64, 64), + ... bbox_min=torch.tensor([-1.0, -1.0, -1.0]), + ... bbox_max=torch.tensor([1.0, 1.0, 1.0]) + ... ) + >>> sample = TensorDict({}) + >>> result = transform(sample) + >>> print(result["grid"].shape) + torch.Size([262144, 3]) # 64*64*64 = 262144 + """ + + def __init__( + self, + output_key: str, + resolution: tuple[int, int, int], + bbox_min: torch.Tensor, + bbox_max: torch.Tensor, + ) -> None: + """ + Initialize the grid creation transform. + + Args: + output_key: Key to store the generated grid. 
+ resolution: Grid resolution as (nx, ny, nz). + bbox_min: Minimum corner of bounding box, shape :math:`(3,)`. + bbox_max: Maximum corner of bounding box, shape :math:`(3,)`. + """ + super().__init__() + self.output_key = output_key + self.resolution = resolution + self.bbox_min = bbox_min + self.bbox_max = bbox_max + + def __call__(self, data: TensorDict) -> TensorDict: + """Create grid and add to sample.""" + device = data.device if data.device is not None else torch.device("cpu") + + # Move bbox to device + bbox_min = self.bbox_min.to(device) + bbox_max = self.bbox_max.to(device) + + nx, ny, nz = self.resolution + + # Create 1D arrays for each dimension + x = torch.linspace(bbox_min[0], bbox_max[0], nx, device=device) + y = torch.linspace(bbox_min[1], bbox_max[1], ny, device=device) + z = torch.linspace(bbox_min[2], bbox_max[2], nz, device=device) + + # Create meshgrid + xv, yv, zv = torch.meshgrid(x, y, z, indexing="ij") + + # Stack into grid of shape (nx*ny*nz, 3) + grid = torch.stack([xv.flatten(), yv.flatten(), zv.flatten()], dim=-1) + + return data.update({self.output_key: grid}) + + def to(self, device: Union[torch.device, str]) -> "CreateGrid": + """Move bounding box tensors to the specified device.""" + super().to(device) + device = torch.device(device) if isinstance(device, str) else device + self.bbox_min = self.bbox_min.to(device) + self.bbox_max = self.bbox_max.to(device) + return self + + def __repr__(self) -> str: + return f"CreateGrid(output_key={self.output_key}, resolution={self.resolution})" + + +@register_transform() +class KNNNeighbors(Transform): + r""" + Compute k-nearest neighbors in a point cloud. + + Finds the k nearest neighbors for each query point and extracts + corresponding coordinates and other attributes. Useful for local + feature aggregation in mesh networks and spatial interpolation. + + Example: + >>> transform = KNNNeighbors( + ... points_key="surface_mesh_centers", + ... queries_key="surface_mesh_centers_subsampled", + ... k=11, + ... output_prefix="surface_neighbors", + ... extract_keys=["surface_normals", "surface_areas"] + ... ) + >>> sample = TensorDict({ + ... "surface_mesh_centers": torch.randn(10000, 3), + ... "surface_mesh_centers_subsampled": torch.randn(1000, 3), + ... "surface_normals": torch.randn(10000, 3), + ... "surface_areas": torch.rand(10000) + ... }) + >>> result = transform(sample) + >>> # Creates: surface_neighbors_coords, surface_neighbors_normals, etc. + """ + + def __init__( + self, + points_key: str, + queries_key: str, + k: int, + *, + output_prefix: str = "neighbors", + extract_keys: Optional[list[str]] = None, + ) -> None: + """ + Initialize the k-NN transform. + + Args: + points_key: Key for reference points to search, shape :math:`(N, 3)`. + queries_key: Key for query points, shape :math:`(M, 3)`. + k: Number of nearest neighbors to find. + output_prefix: Prefix for output keys. + extract_keys: Optional list of keys to extract for neighbors + (e.g., ``["normals", "areas"]``). If None, only extracts coordinates. 
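+
+        Note:
+            Results are written under ``{output_prefix}_indices``,
+            ``{output_prefix}_distances``, ``{output_prefix}_coords``, and
+            ``{output_prefix}_{key}`` for each entry in ``extract_keys``. The
+            indices and distances keep all ``k`` matches, while the gathered
+            coordinates and features drop the first match (treated as the query
+            point itself, per the gather in ``__call__``) and therefore hold
+            ``k - 1`` neighbors each.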
+ """ + super().__init__() + self.points_key = points_key + self.queries_key = queries_key + self.k = k + self.output_prefix = output_prefix + self.extract_keys = extract_keys or [] + + def __call__(self, data: TensorDict) -> TensorDict: + """Compute k-NN and extract neighbor features.""" + if self.points_key not in data: + raise KeyError(f"Points key '{self.points_key}' not found") + if self.queries_key not in data: + raise KeyError(f"Queries key '{self.queries_key}' not found") + + points = data[self.points_key] + queries = data[self.queries_key] + + # Compute k-NN + neighbor_indices, neighbor_distances = knn( + points=points, + queries=queries, + k=self.k, + ) + + updates = {} + + # Store indices and distances + updates[f"{self.output_prefix}_indices"] = neighbor_indices + updates[f"{self.output_prefix}_distances"] = neighbor_distances + + # Extract neighbor coordinates (skip first, which is self) + neighbor_coords = points[neighbor_indices][:, 1:] + updates[f"{self.output_prefix}_coords"] = neighbor_coords + + # Extract additional features for neighbors + for key in self.extract_keys: + if key in data: + neighbor_features = data[key][neighbor_indices][:, 1:] + updates[f"{self.output_prefix}_{key}"] = neighbor_features + + return data.update(updates) + + def __repr__(self) -> str: + return ( + f"KNNNeighbors(points_key={self.points_key}, " + f"queries_key={self.queries_key}, k={self.k})" + ) + + +@register_transform() +class CenterOfMass(Transform): + r""" + Compute weighted center of mass for a point cloud. + + Calculates the center of mass using area or mass weights, typically + applied to mesh data where each point represents a cell with a specific area. + + Example: + >>> transform = CenterOfMass( + ... coords_key="stl_centers", + ... areas_key="stl_areas", + ... output_key="center_of_mass" + ... ) + >>> sample = Sample({ + ... "stl_centers": torch.randn(5000, 3), + ... "stl_areas": torch.rand(5000) + ... }) + >>> result = transform(sample) + >>> print(result["center_of_mass"].shape) + torch.Size([1, 3]) + """ + + def __init__( + self, + coords_key: str, + areas_key: str, + output_key: str, + ) -> None: + """ + Initialize the center of mass transform. + + Args: + coords_key: Key for coordinates, shape :math:`(N, 3)`. + areas_key: Key for area weights, shape :math:`(N,)`. + output_key: Key to store the computed center of mass, shape :math:`(1, 3)`. 
+ """ + super().__init__() + self.coords_key = coords_key + self.areas_key = areas_key + self.output_key = output_key + + def __call__(self, data: TensorDict) -> TensorDict: + """Compute center of mass for the sample.""" + if self.coords_key not in data: + raise KeyError(f"Coordinates key '{self.coords_key}' not found") + if self.areas_key not in data: + raise KeyError(f"Areas key '{self.areas_key}' not found") + + coords = data[self.coords_key] + areas = data[self.areas_key] + + # Compute weighted center of mass + total_area = areas.sum() + weighted_coords = coords * areas.unsqueeze(-1) + center_of_mass = weighted_coords.sum(dim=0) / total_area + + # Ensure shape is (1, 3) + center_of_mass = center_of_mass.unsqueeze(0) + + return data.update({self.output_key: center_of_mass}) + + def __repr__(self) -> str: + return ( + f"CenterOfMass(coords_key={self.coords_key}, output_key={self.output_key})" + ) diff --git a/physicsnemo/datapipes/core/transforms/subsample.py b/physicsnemo/datapipes/core/transforms/subsample.py new file mode 100644 index 0000000000..66a5f18b23 --- /dev/null +++ b/physicsnemo/datapipes/core/transforms/subsample.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Subsampling transforms for point clouds and surfaces. + +Provides efficient subsampling methods for large datasets, including +Poisson disk sampling and weighted sampling. +""" + +from __future__ import annotations + +from typing import Literal, Optional + +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.registry import register_transform +from physicsnemo.datapipes.core.transforms.base import Transform + + +def poisson_sample_indices_fixed(N: int, k: int, device=None) -> torch.Tensor: + """ + Near-uniform sampler of indices for very large arrays. + + This function provides nearly uniform sampling for cases where the number + of indices is very large (> 2^24) and :func:`torch.multinomial` cannot work. + Unlike using :func:`torch.randperm`, there is no need to materialize and + randomize the entire tensor of indices. + + The sampling uses exponentially distributed gaps to achieve near-uniform + coverage without replacement. + + Args: + N: Total number of available indices. + k: Number of indices to sample. + device: Device for the output tensor. + + Returns: + Tensor of shape :math:`(k,)` containing sampled indices. 
+ + Example: + >>> indices = poisson_sample_indices_fixed(1000000, 10000) + >>> print(indices.shape) + torch.Size([10000]) + """ + # Draw exponential gaps off of random initializations + gaps = torch.rand(k, device=device).exponential_() + + summed = gaps.sum() + + # Normalize so total cumulative sum == N + gaps *= N / summed + + # Compute cumulative positions + idx = torch.cumsum(gaps, dim=0) + + # Shift down so range starts at 0 and ends below N + idx -= gaps[0] / 2 + + # Round to nearest integer index + idx = torch.clamp(idx.floor().long(), min=0, max=N - 1) + + return idx + + +def shuffle_array( + points: torch.Tensor, + n_points: int, + weights: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sample points with or without weights. + + Args: + points: Input tensor to sample from, shape :math:`(N, ...)` + n_points: Number of points to sample. + weights: Optional weights for sampling, shape :math:`(N,)`. + If None, uses uniform sampling. + + Returns: + Tuple of (sampled_points, indices) where: + - sampled_points: Sampled tensor, shape :math:`(n\\_points, ...)` + - indices: Selected indices, shape :math:`(n\\_points,)` + """ + N = points.shape[0] + device = points.device + + if N < n_points: + # If not enough points, return all points + indices = torch.arange(N, device=device) + return points, indices + + if weights is not None: + # Weighted sampling + indices = torch.multinomial(weights, n_points, replacement=False) + else: + # Uniform sampling + if N > 2**24: + # Use Poisson sampling for very large arrays + indices = poisson_sample_indices_fixed(N, n_points, device=device) + else: + # Use standard multinomial for smaller arrays + indices = torch.randperm(N, device=device)[:n_points] + + sampled_points = points[indices] + return sampled_points, indices + + +@register_transform() +class SubsamplePoints(Transform): + r""" + Subsample points from large point clouds or meshes. + + This transform applies coordinated subsampling to multiple tensor fields, + ensuring that the same points are selected across all specified keys. + Useful for downsampling large volumetric data or point clouds while + maintaining correspondence between coordinates and field values. + + Supports two sampling algorithms: + - ``"poisson_fixed"``: Near-uniform sampling for very large datasets (> 2^24 points) + - ``"uniform"``: Standard uniform sampling + + Optionally supports weighted sampling (e.g., area-weighted for surface meshes) + by providing a ``weights_key``. + + Example (uniform sampling): + >>> # Subsample volume data + >>> transform = SubsamplePoints( + ... input_keys=["volume_mesh_centers", "volume_fields"], + ... n_points=10000, + ... algorithm="poisson_fixed" + ... ) + >>> sample = TensorDict({ + ... "volume_mesh_centers": torch.randn(100000, 3), + ... "volume_fields": torch.randn(100000, 5) + ... }) + >>> result = transform(sample) + >>> print(result["volume_mesh_centers"].shape) + torch.Size([10000, 3]) + + Example (weighted sampling): + >>> # Area-weighted surface sampling + >>> transform = SubsamplePoints( + ... input_keys=["surface_mesh_centers", "surface_fields", "surface_normals"], + ... n_points=5000, + ... algorithm="uniform", + ... weights_key="surface_areas" + ... ) + >>> sample = Sample({ + ... "surface_mesh_centers": torch.randn(20000, 3), + ... "surface_fields": torch.randn(20000, 2), + ... "surface_normals": torch.randn(20000, 3), + ... "surface_areas": torch.rand(20000) + ... 
}) + >>> result = transform(sample) + >>> print(result["surface_mesh_centers"].shape) + torch.Size([5000, 3]) + + Note: + All specified keys must have the same size in their first dimension. + The same indices are applied to all keys to maintain correspondence. + """ + + def __init__( + self, + input_keys: list[str], + n_points: int, + *, + algorithm: Literal["poisson_fixed", "uniform"] = "poisson_fixed", + weights_key: Optional[str] = None, + ) -> None: + """ + Initialize the subsample transform. + + Args: + input_keys: List of tensor keys to subsample. All must have the same + first dimension size. + n_points: Number of points to sample. + algorithm: Sampling algorithm to use. Options: + - ``"poisson_fixed"``: Near-uniform sampling (default) + - ``"uniform"``: Standard uniform sampling + weights_key: Optional key for sampling weights (e.g., ``"surface_areas"`` + for area-weighted surface sampling). When provided, samples + are drawn according to the weights distribution. + """ + super().__init__() + self.input_keys = input_keys + self.n_points = n_points + self.algorithm = algorithm + self.weights_key = weights_key + + def __call__(self, data: TensorDict) -> TensorDict: + """Apply subsampling to the TensorDict.""" + if not self.input_keys: + return data + + # Check that all keys are present + for key in self.input_keys: + if key not in data.keys(): + raise KeyError( + f"Key '{key}' not found in data. " + f"Available keys: {list(data.keys())}" + ) + + # Get the first key to determine indices + first_key = self.input_keys[0] + first_tensor = data[first_key] + N = first_tensor.shape[0] + + # Check that all keys have the same first dimension + for key in self.input_keys[1:]: + if data[key].shape[0] != N: + raise ValueError( + f"All keys must have the same first dimension. " + f"Key '{first_key}' has {N}, but '{key}' has {data[key].shape[0]}" + ) + + # Skip if already fewer points than requested + if N <= self.n_points: + return data + + # Get weights if provided + weights = None + if self.weights_key is not None: + if self.weights_key not in data.keys(): + raise KeyError( + f"Weights key '{self.weights_key}' not found in data. " + f"Available keys: {list(data.keys())}" + ) + weights = data[self.weights_key] + + # Sample indices + device = first_tensor.device + if weights is not None: + # Weighted sampling + _, indices = shuffle_array(first_tensor, self.n_points, weights=weights) + elif self.algorithm == "poisson_fixed" and N > 2**24: + indices = poisson_sample_indices_fixed(N, self.n_points, device=device) + else: + # Use uniform sampling + indices = torch.randperm(N, device=device)[: self.n_points] + + # Apply indices to all keys + updates = {} + for key in self.input_keys: + updates[key] = data[key][indices] + + return data.update(updates) + + def __repr__(self) -> str: + weights_str = f", weights_key={self.weights_key}" if self.weights_key else "" + return ( + f"SubsamplePoints(input_keys={self.input_keys}, n_points={self.n_points}, " + f"algorithm={self.algorithm}{weights_str})" + ) diff --git a/test/datapipes/core/__init__.py b/test/datapipes/core/__init__.py new file mode 100644 index 0000000000..7fba337f8b --- /dev/null +++ b/test/datapipes/core/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test suite for datapipe.""" diff --git a/test/datapipes/core/conftest.py b/test/datapipes/core/conftest.py new file mode 100644 index 0000000000..6ca1224457 --- /dev/null +++ b/test/datapipes/core/conftest.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared fixtures for datapipe tests.""" + +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch +from tensordict import TensorDict + +# ============================================================================ +# Sample Fixtures +# ============================================================================ + + +@pytest.fixture +def simple_sample(): + """A simple sample with basic tensors.""" + data = TensorDict( + { + "x": torch.randn(100, 3), + "y": torch.randn(100), + } + ) + return data, {} + + +@pytest.fixture +def sample_with_metadata(): + """A sample with metadata.""" + data = TensorDict({"pressure": torch.randn(50), "velocity": torch.randn(50, 3)}) + metadata = {"filename": "test.h5", "index": 42} + return data, metadata + + +@pytest.fixture +def batch_of_samples(): + """Multiple samples for collation tests.""" + return [ + ( + TensorDict({"x": torch.randn(10, 3), "y": torch.randn(10)}), + {"idx": i}, + ) + for i in range(4) + ] + + +@pytest.fixture +def ragged_samples(): + """Samples with different sizes (for ConcatCollator).""" + return [ + (TensorDict({"points": torch.randn(100, 3)}), {"idx": 0}), + (TensorDict({"points": torch.randn(150, 3)}), {"idx": 1}), + (TensorDict({"points": torch.randn(80, 3)}), {"idx": 2}), + ] + + +# ============================================================================ +# Synthetic Data Fixtures +# ============================================================================ + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory that's cleaned up after the test.""" + path = Path(tempfile.mkdtemp()) + yield path + shutil.rmtree(path) + + +@pytest.fixture +def numpy_data_dir(temp_dir): + """Create a directory with .npz files for NumpyReader tests.""" + for i in range(10): + np.savez( + temp_dir / f"sample_{i:03d}.npz", + positions=np.random.randn(100, 3).astype(np.float32), + features=np.random.randn(100, 8).astype(np.float32), + label=np.array([i], dtype=np.int64), + ) + return temp_dir + + +@pytest.fixture +def numpy_npz_file(temp_dir): + """Create a single .npz file with multiple arrays.""" + path = temp_dir / 
"data.npz" + np.savez( + path, + images=np.random.randn(15, 32, 32).astype(np.float32), + labels=np.arange(15, dtype=np.int64), + ) + return path + + +@pytest.fixture +def hdf5_data_dir(temp_dir): + """Create a directory with .h5 files for HDF5Reader tests.""" + h5py = pytest.importorskip("h5py") + + for i in range(10): + with h5py.File(temp_dir / f"sample_{i:03d}.h5", "w") as f: + f.create_dataset("mesh", data=np.random.randn(200, 3).astype(np.float32)) + f.create_dataset("pressure", data=np.random.randn(200).astype(np.float32)) + f.create_dataset( + "velocity", data=np.random.randn(200, 3).astype(np.float32) + ) + + return temp_dir + + +@pytest.fixture +def hdf5_single_file(temp_dir): + """Create a single .h5 file with samples indexed along first dim.""" + h5py = pytest.importorskip("h5py") + + path = temp_dir / "data.h5" + with h5py.File(path, "w") as f: + f.create_dataset("inputs", data=np.random.randn(25, 64).astype(np.float32)) + f.create_dataset("targets", data=np.random.randn(25, 10).astype(np.float32)) + + return path + + +@pytest.fixture +def zarr_data_dir(temp_dir): + """Create a directory with .zarr groups for ZarrReader tests.""" + zarr = pytest.importorskip("zarr") + + for i in range(10): + group_path = temp_dir / f"sample_{i:03d}.zarr" + root = zarr.open(group_path, mode="w") + root.create_array("field_a", data=np.random.randn(50, 50).astype(np.float32)) + root.create_array("field_b", data=np.random.randn(50).astype(np.float32)) + + return temp_dir + + +@pytest.fixture +def zarr_single_group(temp_dir): + """Create a single .zarr group with samples indexed along first dim.""" + zarr = pytest.importorskip("zarr") + + path = temp_dir / "data.zarr" + root = zarr.open(path, mode="w") + root.create_array("data", data=np.random.randn(30, 16, 16).astype(np.float32)) + root.create_array( + "mask", data=np.random.randint(0, 2, (30, 16, 16)).astype(np.uint8) + ) + + return path + + +# ============================================================================ +# Device Fixtures +# ============================================================================ + + +@pytest.fixture +def device(): + """Return available device (cuda if available, else cpu).""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.fixture +def cuda_available(): + """Skip test if CUDA is not available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + return True diff --git a/test/datapipes/core/readers/test_numpy_consolidated.py b/test/datapipes/core/readers/test_numpy_consolidated.py new file mode 100644 index 0000000000..377b835be6 --- /dev/null +++ b/test/datapipes/core/readers/test_numpy_consolidated.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Tests for the NumpyReader. + +Tests reading from .npz files, directories, and coordinated subsampling. 
+""" + +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch + +from physicsnemo.datapipes.core.readers import NumpyReader + + +class TestNumpyReaderBasic: + """Basic functionality tests for NumpyReader.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_single_npz_file(self): + """Test reading from a single .npz file.""" + # Create test data + coords = np.random.randn(20, 3).astype(np.float32) + features = np.random.randn(20, 5).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + # Create reader + reader = NumpyReader(npz_path, fields=["coords", "features"]) + + # Check properties + assert len(reader) == 20 + assert set(reader.field_names) == {"coords", "features"} + + # Load sample + data, metadata = reader[0] + assert "coords" in data + assert "features" in data + assert data["coords"].shape == (3,) + assert data["features"].shape == (5,) + + def test_single_npz_file_load_all_fields(self): + """Test reading all fields from a single .npz file when fields=None.""" + # Create test data + coords = np.random.randn(20, 3).astype(np.float32) + features = np.random.randn(20, 5).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + + # Create reader without specifying fields + reader = NumpyReader(npz_path) + + # Should load all fields + assert set(reader.field_names) == {"coords", "features"} + + data, metadata = reader[0] + assert "coords" in data + assert "features" in data + + def test_directory_of_npz_files(self): + """Test reading from a directory of .npz files.""" + # Create test data + for i in range(5): + coords = np.random.randn(100, 3).astype(np.float32) + features = np.random.randn(100, 2).astype(np.float32) + + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords, features=features) + + # Create reader + reader = NumpyReader( + self.temp_path, file_pattern="sample_*.npz", fields=["coords", "features"] + ) + + # Check properties + assert len(reader) == 5 + assert set(reader.field_names) == {"coords", "features"} + + # Load sample + data, metadata = reader[0] + assert data["coords"].shape == (100, 3) + assert data["features"].shape == (100, 2) + + def test_directory_load_all_fields(self): + """Test reading all fields from directory when fields=None.""" + # Create test data + for i in range(3): + coords = np.random.randn(50, 3).astype(np.float32) + features = np.random.randn(50, 2).astype(np.float32) + + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords, features=features) + + # Create reader without specifying fields + reader = NumpyReader(self.temp_path, file_pattern="sample_*.npz") + + # Should load all fields + assert set(reader.field_names) == {"coords", "features"} + + data, metadata = reader[0] + assert "coords" in data + assert "features" in data + + def test_default_values(self): + """Test optional keys with default values.""" + # Create test data with only some keys + coords = np.random.randn(10, 100, 3).astype(np.float32) + features = np.random.randn(10, 100, 2).astype(np.float32) + + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords, features=features) + # Note: no "normals" key + + # 
Create reader with optional key + default_normals = torch.zeros(100, 3) + reader = NumpyReader( + npz_path, + fields=["coords", "features", "normals"], + default_values={"normals": default_normals}, + ) + + # Load sample + data, metadata = reader[0] + assert "coords" in data + assert "features" in data + assert "normals" in data + + # Check that default was used + assert torch.allclose(data["normals"], default_normals) + + def test_unsupported_file_type(self): + """Test that .npy files raise an error.""" + npy_path = self.temp_path / "data.npy" + np.save(npy_path, np.random.randn(10, 3, 4)) + + with pytest.raises(ValueError, match="Unsupported file type"): + NumpyReader(npy_path) + + +class TestNumpyReaderCoordinatedSubsampling: + """Test coordinated subsampling functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_coordinated_subsampling_directory_npz(self): + """Test coordinated subsampling in directory mode.""" + # Create test data with large arrays + n_samples = 5 + n_points = 100000 + subsample_points = 10000 + + for i in range(n_samples): + coords = np.random.randn(n_points, 3).astype(np.float32) + features = np.random.randn(n_points, 4).astype(np.float32) + areas = np.random.rand(n_points).astype(np.float32) + + npz_path = self.temp_path / f"sample_{i:03d}.npz" + np.savez(npz_path, coords=coords, features=features, areas=areas) + + # Create reader with coordinated subsampling + reader = NumpyReader( + self.temp_path, + file_pattern="sample_*.npz", + fields=["coords", "features", "areas"], + coordinated_subsampling={ + "n_points": subsample_points, + "target_keys": ["coords", "features"], + }, + ) + + # Load sample + data, metadata = reader[0] + + # Check that subsampled arrays have correct size + assert data["coords"].shape == (subsample_points, 3) + assert data["features"].shape == (subsample_points, 4) + + # Non-target keys should be full size + assert data["areas"].shape == (n_points,) + + def test_supports_coordinated_subsampling(self): + """Test that coordinated subsampling is only supported in directory mode.""" + # Directory mode: supported + npz_path = self.temp_path / "sample_000.npz" + np.savez(npz_path, coords=np.random.randn(100, 3)) + + reader_dir = NumpyReader(self.temp_path, file_pattern="sample_*.npz") + assert reader_dir._supports_coordinated_subsampling is True + + # Single .npz file mode: not supported + single_npz_path = self.temp_path / "single.npz" + np.savez(single_npz_path, coords=np.random.randn(10, 100, 3)) + + reader_single = NumpyReader(single_npz_path) + assert reader_single._supports_coordinated_subsampling is False + + # Config is ignored for readers that don't support it + reader_with_config = NumpyReader( + single_npz_path, + coordinated_subsampling={"n_points": 50, "target_keys": ["coords"]}, + ) + # Config is stored but will be ignored during loading + assert reader_with_config._coordinated_subsampling_config is not None + + +class TestNumpyReaderMemoryManagement: + """Test memory management and cleanup.""" + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA 
required") + def test_pin_memory(self): + """Test pin_memory functionality.""" + coords = np.random.randn(10, 3, 4).astype(np.float32) + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + # Create reader with pin_memory + reader = NumpyReader(npz_path, pin_memory=True) + data, metadata = reader[0] + + # Check that tensor is pinned + assert data["coords"].is_pinned() + + def test_close_handles(self): + """Test that file handles are properly closed.""" + coords = np.random.randn(20, 3).astype(np.float32) + npz_path = self.temp_path / "data.npz" + np.savez(npz_path, coords=coords) + + reader = NumpyReader(npz_path) + _ = reader[0] + + # Close should not raise + reader.close() + + # Should be able to open again + reader2 = NumpyReader(npz_path) + _ = reader2[0] + reader2.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/datapipes/core/test_collate.py b/test/datapipes/core/test_collate.py new file mode 100644 index 0000000000..2f5831b75d --- /dev/null +++ b/test/datapipes/core/test_collate.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for collation utilities.""" + +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes.core as dp + +# ============================================================================ +# DefaultCollator (stack-based) +# ============================================================================ + + +def test_default_collate_basic(batch_of_samples): + collator = dp.DefaultCollator(collate_metadata=False) + batched_data = collator(batch_of_samples) + + # 4 samples, each with shape (10, 3) -> (4, 10, 3) + assert batched_data["x"].shape == (4, 10, 3) + assert batched_data["y"].shape == (4, 10) + + +def test_default_collate_metadata(batch_of_samples): + collator = dp.DefaultCollator(collate_metadata=True) + batched_data, metadata_list = collator(batch_of_samples) + + # Metadata should be collected into lists + assert isinstance(metadata_list, list) + assert len(metadata_list) == 4 + assert [m["idx"] for m in metadata_list] == [0, 1, 2, 3] + + +def test_default_collate_empty_raises(): + collator = dp.DefaultCollator() + + with pytest.raises(ValueError, match="empty sequence"): + collator([]) + + +def test_default_collate_mismatched_keys_raises(): + samples = [ + (TensorDict({"x": torch.randn(10)}), {}), + (TensorDict({"y": torch.randn(10)}), {}), # Different key! + ] + collator = dp.DefaultCollator() + + with pytest.raises( + RuntimeError, match="sets of keys in the tensordicts to stack are exclusive" + ): + collator(samples) + + +def test_default_collate_mismatched_shapes_raises(): + samples = [ + (TensorDict({"x": torch.randn(10, 3)}), {}), + (TensorDict({"x": torch.randn(20, 3)}), {}), # Different shape! 
+ ] + collator = dp.DefaultCollator() + + with pytest.raises( + RuntimeError, match="shapes of the tensors to stack is incompatible" + ): + collator(samples) + + +def test_default_collate_specific_keys(batch_of_samples): + collator = dp.DefaultCollator(keys=["x"], collate_metadata=False) + batched_data = collator(batch_of_samples) + + assert "x" in batched_data + assert "y" not in batched_data + + +def test_default_collate_different_stack_dim(): + samples = [ + (TensorDict({"x": torch.randn(3, 10)}), {}), + (TensorDict({"x": torch.randn(3, 10)}), {}), + ] + collator = dp.DefaultCollator(stack_dim=1, collate_metadata=False) + batched_data = collator(samples) + + # Stack along dim 1: (3, 10) -> (3, 2, 10) + assert batched_data["x"].shape == (3, 2, 10) + + +def test_default_collate_disable_metadata(): + samples = [ + (TensorDict({"x": torch.randn(10)}), {"idx": 0}), + (TensorDict({"x": torch.randn(10)}), {"idx": 1}), + ] + collator = dp.DefaultCollator(collate_metadata=False) + _ = collator(samples) + + +# ============================================================================ +# ConcatCollator (concat-based) +# ============================================================================ + + +def test_concat_collate_ragged(ragged_samples): + collator = dp.ConcatCollator(dim=0, add_batch_idx=True) + batched_data = collator(ragged_samples) + + # 100 + 150 + 80 = 330 points + assert batched_data["points"].shape == (330, 3) + assert batched_data["batch_idx"].shape == (330,) + + +def test_concat_batch_idx_values(ragged_samples): + collator = dp.ConcatCollator(dim=0, add_batch_idx=True) + batched_data = collator(ragged_samples) + + # First 100 should be 0, next 150 should be 1, last 80 should be 2 + assert (batched_data["batch_idx"][:100] == 0).all() + assert (batched_data["batch_idx"][100:250] == 1).all() + assert (batched_data["batch_idx"][250:] == 2).all() + + +def test_concat_collate_no_batch_idx(ragged_samples): + collator = dp.ConcatCollator(dim=0, add_batch_idx=False) + batched_data = collator(ragged_samples) + + assert "batch_idx" not in batched_data + + +def test_concat_collate_custom_batch_idx_key(ragged_samples): + collator = dp.ConcatCollator( + dim=0, + add_batch_idx=True, + batch_idx_key="sample_id", + ) + batched_data = collator(ragged_samples) + + assert "sample_id" in batched_data + assert "batch_idx" not in batched_data + + +def test_concat_collate_metadata(ragged_samples): + collator = dp.ConcatCollator(dim=0, collate_metadata=True) + batched_data, metadata_list = collator(ragged_samples) + + assert len(metadata_list) == 3 + assert [m["idx"] for m in metadata_list] == [0, 1, 2] + + +def test_concat_collate_empty_raises(): + collator = dp.ConcatCollator() + + with pytest.raises(ValueError, match="empty sequence"): + collator([]) + + +# ============================================================================ +# FunctionCollator +# ============================================================================ + + +def test_function_collator(): + def my_collate(samples): + # Just sum all tensors + data_list = [data for data, _ in samples] + total = sum(d["x"].sum() for d in data_list) + return TensorDict({"total": total.unsqueeze(0)}) + + samples = [ + (TensorDict({"x": torch.ones(10)}), {}), + (TensorDict({"x": torch.ones(10) * 2}), {}), + ] + + collator = dp.FunctionCollator(my_collate) + batched_data = collator(samples) + print(type(batched_data)) + + # 10*1 + 10*2 = 30 + assert batched_data["total"].item() == 30.0 + + +# 
============================================================================ +# Collation convenience functions +# ============================================================================ + + +def test_default_collate_function(batch_of_samples): + batched_data = dp.default_collate(batch_of_samples) + + assert batched_data["x"].shape == (4, 10, 3) + + +def test_concat_collate_function(ragged_samples): + batched_data = dp.concat_collate(ragged_samples, dim=0, add_batch_idx=True) + + assert batched_data["points"].shape == (330, 3) + assert "batch_idx" in batched_data + + +def test_get_collator_none(): + collator = dp.get_collator(None) + assert isinstance(collator, dp.DefaultCollator) + + +def test_get_collator_instance(): + original = dp.ConcatCollator() + collator = dp.get_collator(original) + assert collator is original + + +def test_get_collator_function(): + def my_fn(samples): + return samples[0] + + collator = dp.get_collator(my_fn) + assert isinstance(collator, dp.FunctionCollator) + + +def test_get_collator_invalid(): + with pytest.raises(TypeError): + dp.get_collator("not a collator") + + +# ============================================================================ +# Collation with device +# ============================================================================ + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_collate_cuda_samples(): + samples = [ + (TensorDict({"x": torch.randn(10, device="cuda")}), {}), + (TensorDict({"x": torch.randn(10, device="cuda")}), {}), + ] + + batched_data, metadata_list = dp.default_collate(samples) + assert batched_data["x"].device.type == "cuda" diff --git a/test/datapipes/core/test_dataloader.py b/test/datapipes/core/test_dataloader.py new file mode 100644 index 0000000000..9a58738f9c --- /dev/null +++ b/test/datapipes/core/test_dataloader.py @@ -0,0 +1,428 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for DataLoader class.""" + +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes.core as dp + +# ============================================================================ +# Basic DataLoader functionality +# ============================================================================ + + +def test_create_dataloader(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=2) + + # 10 samples / 2 batch_size = 5 batches + assert len(loader) == 5 + + +def test_iterate_batches(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=2) + + batches = list(loader) + assert len(batches) == 5 + + for batched_data in batches: + assert isinstance(batched_data, TensorDict) + assert batched_data["positions"].shape[0] == 2 # batch dim + + +def test_batch_collation(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=4) + + batched_data = next(iter(loader)) + + # Should have batch dimension + assert batched_data["positions"].shape == (4, 100, 3) + assert batched_data["features"].shape == (4, 100, 8) + + +def test_metadata_collation(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=3, collate_metadata=True) + + batched_data, metadata_list = next(iter(loader)) + + # Metadata should be lists + assert isinstance(metadata_list, list) + assert len(metadata_list) == 3 + assert [m["index"] for m in metadata_list] == [0, 1, 2] + + +def test_drop_last(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + # Without drop_last: 10 samples / 3 = 4 batches (last has 1) + loader_keep = dp.DataLoader(dataset, batch_size=3, drop_last=False) + assert len(loader_keep) == 4 + + # With drop_last: 10 samples / 3 = 3 batches + loader_drop = dp.DataLoader(dataset, batch_size=3, drop_last=True) + assert len(loader_drop) == 3 + + +def test_last_batch_smaller(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=3, drop_last=False) + + batches = list(loader) + last_batched_data = batches[-1] + + # 10 % 3 = 1, so last batch should have 1 sample + assert last_batched_data["positions"].shape[0] == 1 + + +# ============================================================================ +# DataLoader shuffling +# ============================================================================ + + +def test_shuffle_changes_order(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + # Collect indices from multiple epochs + torch.manual_seed(42) + loader = dp.DataLoader(dataset, batch_size=2, shuffle=True, collate_metadata=True) + + indices_epoch1 = [] + for batched_data, metadata_list in loader: + indices_epoch1.extend([m["index"] for m in metadata_list]) + + indices_epoch2 = [] + for batched_data, metadata_list in loader: + indices_epoch2.extend([m["index"] for m in metadata_list]) + + # Different epochs should (likely) have different orders + # Note: there's a tiny chance they're the same, but very unlikely + # We mainly check that shuffling doesn't break anything + assert set(indices_epoch1) == set(range(10)) + assert set(indices_epoch2) == set(range(10)) + + +def test_no_shuffle_preserves_order(numpy_data_dir): 
+ reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=2, shuffle=False, collate_metadata=True) + + indices = [] + for batched_data, metadata_list in loader: + indices.extend([m["index"] for m in metadata_list]) + + assert indices == list(range(10)) + + +# ============================================================================ +# DataLoader prefetching +# ============================================================================ + + +def test_prefetch_disabled(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, + batch_size=2, + prefetch_factor=0, # Disabled + ) + + batches = list(loader) + assert len(batches) == 5 + + +def test_prefetch_enabled(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, + batch_size=2, + prefetch_factor=2, + use_streams=False, # CPU mode + ) + + batches = list(loader) + assert len(batches) == 5 + + for batched_data in batches: + assert batched_data["positions"].shape[0] == 2 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_prefetch_with_streams(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset(reader, device="cuda:0") + loader = dp.DataLoader( + dataset, + batch_size=2, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + + batches = list(loader) + assert len(batches) == 5 + + for batched_data in batches: + assert batched_data["positions"].device.type == "cuda" + + +def test_disable_prefetch(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, + batch_size=2, + prefetch_factor=2, + ) + + loader.disable_prefetch() + + # Should still work in sync mode + batches = list(loader) + assert len(batches) == 5 + + +# ============================================================================ +# DataLoader custom collation +# ============================================================================ + + +def test_default_collation(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=3) + + batched_data = next(iter(loader)) + + # Default collation stacks tensors + assert batched_data["positions"].shape == (3, 100, 3) + + +def test_concat_collation(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, + batch_size=3, + collate_fn=dp.ConcatCollator(dim=0, add_batch_idx=True), + ) + + batched_data = next(iter(loader)) + + # Concat collation concatenates along dim 0 + assert batched_data["positions"].shape == (300, 3) # 3 * 100 points + assert "batch_idx" in batched_data + assert batched_data["batch_idx"].shape == (300,) + + +def test_custom_collate_fn(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + def my_collate(samples): + # Just return first sample + return samples[0] + + loader = dp.DataLoader( + dataset, batch_size=3, collate_fn=my_collate, collate_metadata=True + ) + + result = next(iter(loader)) + + # Should be single sample tuple, not batched + data, metadata = result + assert data["positions"].shape == (100, 3) + + +# ============================================================================ +# DataLoader with custom samplers +# 
============================================================================ + + +def test_sequential_sampler(numpy_data_dir): + from torch.utils.data import SequentialSampler + + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + sampler = SequentialSampler(dataset) + loader = dp.DataLoader( + dataset, batch_size=2, sampler=sampler, collate_metadata=True + ) + + indices = [] + for batched_data, metadata_list in loader: + indices.extend([m["index"] for m in metadata_list]) + + assert indices == list(range(10)) + + +def test_random_sampler(numpy_data_dir): + from torch.utils.data import RandomSampler + + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + torch.manual_seed(123) + sampler = RandomSampler(dataset) + loader = dp.DataLoader( + dataset, batch_size=2, sampler=sampler, collate_metadata=True + ) + + indices = [] + for batched_data, metadata_list in loader: + indices.extend([m["index"] for m in metadata_list]) + + # All indices present, but possibly shuffled + assert set(indices) == set(range(10)) + + +def test_subset_sampler(numpy_data_dir): + from torch.utils.data import SubsetRandomSampler + + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + # Only use indices 0, 2, 4, 6, 8 + indices = [0, 2, 4, 6, 8] + sampler = SubsetRandomSampler(indices) + loader = dp.DataLoader( + dataset, batch_size=2, sampler=sampler, collate_metadata=True + ) + + seen_indices = [] + for batched_data, metadata_list in loader: + seen_indices.extend([m["index"] for m in metadata_list]) + + assert set(seen_indices) == set(indices) + + +def test_set_epoch(numpy_data_dir): + """Test set_epoch for DistributedSampler compatibility.""" + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=2) + + # Should not raise even if sampler doesn't have set_epoch + loader.set_epoch(0) + loader.set_epoch(1) + + +# ============================================================================ +# End-to-end tests +# ============================================================================ + + +def test_training_loop_simulation(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset( + reader, + transforms=dp.SubsamplePoints( + input_keys=["positions", "features"], n_points=50 + ), + ) + loader = dp.DataLoader( + dataset, + batch_size=2, + shuffle=True, + ) + + # Simulate 3 epochs + for epoch in range(3): + loader.set_epoch(epoch) + + total_samples = 0 + for batched_data in loader: + batch_size = batched_data["positions"].shape[0] + total_samples += batch_size + + # Verify transform was applied + assert batched_data["positions"].shape[1] == 50 + + assert total_samples == 10 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_gpu_training_loop(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + transforms=[ + dp.Normalize( + input_keys=["positions"], + method="mean_std", + means={"positions": 0.0}, + stds={"positions": 1.0}, + ), + ], + ) + loader = dp.DataLoader( + dataset, + batch_size=4, + shuffle=True, + prefetch_factor=2, + num_streams=4, + ) + + for batched_data in loader: + assert batched_data["positions"].device.type == "cuda" + + # Simulate forward pass + _ = batched_data["positions"].mean() + + torch.cuda.synchronize() + + +# ============================================================================ +# DataLoader errors +# 
============================================================================ + + +def test_invalid_batch_size(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + with pytest.raises(ValueError, match="batch_size must be >= 1"): + dp.DataLoader(dataset, batch_size=0) + + +# ============================================================================ +# DataLoader repr +# ============================================================================ + + +def test_dataloader_repr(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader(dataset, batch_size=4) + + repr_str = repr(loader) + assert "DataLoader" in repr_str + assert "batch_size=4" in repr_str diff --git a/test/datapipes/core/test_dataset.py b/test/datapipes/core/test_dataset.py new file mode 100644 index 0000000000..873ea0b7d4 --- /dev/null +++ b/test/datapipes/core/test_dataset.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Tests for Dataset class.""" + +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes.core as dp + +# ============================================================================ +# Basic Dataset functionality +# ============================================================================ + + +def test_create_dataset(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + assert len(dataset) == 10 + + +def test_dataset_get_sample(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + data, metadata = dataset[0] + assert isinstance(data, TensorDict) + assert "positions" in data + + +def test_dataset_iteration(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + samples = list(dataset) + assert len(samples) == 10 + + +def test_dataset_field_names(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + assert "positions" in dataset.field_names + assert "features" in dataset.field_names + + +def test_dataset_context_manager(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + with dp.Dataset(reader) as dataset: + data, metadata = dataset[0] + assert "positions" in data + + +# ============================================================================ +# Dataset with transforms +# ============================================================================ + + +def test_dataset_single_transform(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + norm = dp.Normalize( + input_keys=["positions"], + method="mean_std", + means={"positions": 0.0}, + stds={"positions": 1.0}, + ) + dataset = dp.Dataset(reader, transforms=norm) + + data, metadata = dataset[0] + assert "positions" in data + + +def 
test_dataset_transform_list(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset( + reader, + transforms=[ + dp.SubsamplePoints(input_keys=["positions", "features"], n_points=50), + ], + ) + + data, metadata = dataset[0] + assert data["positions"].shape[0] == 50 + assert data["features"].shape[0] == 50 + + +def test_dataset_compose_transforms(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset( + reader, + transforms=dp.Compose( + [ + dp.SubsamplePoints(input_keys=["positions", "features"], n_points=50), + dp.Normalize( + input_keys=["positions"], + method="mean_std", + means={"positions": 0.0}, + stds={"positions": 1.0}, + ), + ] + ), + ) + + data, metadata = dataset[0] + assert data["positions"].shape[0] == 50 + + +def test_dataset_empty_transforms_list(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader, transforms=[]) + + data, metadata = dataset[0] + # Should work, no transforms applied + assert "positions" in data + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_dataset_to_device_transform(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + ) + + data, metadata = dataset[0] + assert data["positions"].device.type == "cuda" + + +# ============================================================================ +# Dataset prefetching +# ============================================================================ + + +def test_prefetch_single(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + # Prefetch index 0 + dataset.prefetch(0) + + # Should have 1 prefetch in flight + assert dataset.prefetch_count >= 0 # May complete quickly + + # Get should use prefetched result + data, metadata = dataset[0] + assert "positions" in data + + +def test_prefetch_batch(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + # Prefetch multiple indices + dataset.prefetch_batch([0, 1, 2, 3]) + + # Get samples + for i in range(4): + data, metadata = dataset[i] + assert metadata["index"] == i + + +def test_prefetch_non_prefetched_index(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + # Prefetch index 0 + dataset.prefetch(0) + + # Get non-prefetched index (should load synchronously) + data, metadata = dataset[5] + assert metadata["index"] == 5 + + +def test_prefetch_cancel(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + dataset.prefetch_batch([0, 1, 2, 3]) + dataset.cancel_prefetch() + + # Prefetch count should be 0 after cancel + assert dataset.prefetch_count == 0 + + +def test_prefetch_cancel_specific(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + dataset.prefetch(0) + dataset.prefetch(1) + dataset.cancel_prefetch(0) + + # Should still be able to get index 1 from prefetch + # and index 0 synchronously + data0, metadata0 = dataset[0] + data1, metadata1 = dataset[1] + + assert metadata0["index"] == 0 + assert metadata1["index"] == 1 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_prefetch_with_stream(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + ) + + stream = torch.cuda.Stream() + dataset.prefetch(0, stream=stream) + + data, metadata = dataset[0] + 
assert data["positions"].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_prefetch_batch_with_streams(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + ) + + streams = [torch.cuda.Stream() for _ in range(4)] + dataset.prefetch_batch([0, 1, 2, 3], streams=streams) + + for i in range(4): + data, metadata = dataset[i] + assert data["positions"].device.type == "cuda" + + +def test_prefetch_with_transforms(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset( + reader, + transforms=dp.SubsamplePoints( + input_keys=["positions", "features"], n_points=50 + ), + ) + + dataset.prefetch(0) + data, metadata = dataset[0] + + # Transform should have been applied + assert data["positions"].shape[0] == 50 + + +def test_close_stops_prefetch(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + dataset.prefetch_batch([0, 1, 2, 3]) + dataset.close() + + # Should not raise, prefetch should be stopped + assert dataset.prefetch_count == 0 + + +# ============================================================================ +# Dataset errors +# ============================================================================ + + +def test_invalid_reader_type(): + with pytest.raises(TypeError, match="must be a Reader"): + dp.Dataset("not a reader") + + +def test_invalid_transforms_type(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + + with pytest.raises(TypeError, match="must be Transform"): + dp.Dataset(reader, transforms="not a transform") + + +# ============================================================================ +# Dataset repr +# ============================================================================ + + +def test_dataset_repr(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + + repr_str = repr(dataset) + assert "Dataset" in repr_str + assert "NumpyReader" in repr_str + + +def test_dataset_repr_with_transforms(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset( + reader, + transforms=dp.Normalize( + input_keys=["positions"], + method="mean_std", + means={"positions": 0.0}, + stds={"positions": 1.0}, + ), + ) + + repr_str = repr(dataset) + assert "Normalize" in repr_str diff --git a/test/datapipes/core/test_readers.py b/test/datapipes/core/test_readers.py new file mode 100644 index 0000000000..6ac365c6c8 --- /dev/null +++ b/test/datapipes/core/test_readers.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Tests for data readers.""" + +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes.core as dp + +# ============================================================================ +# NumpyReader - Directory mode +# ============================================================================ + + +def test_numpy_load_from_directory(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, file_pattern="sample_*.npz") + + assert len(reader) == 10 + assert "positions" in reader.field_names + assert "features" in reader.field_names + + +def test_numpy_get_sample(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + data, metadata = reader[0] + + assert isinstance(data, TensorDict) + assert data["positions"].shape == (100, 3) + assert data["features"].shape == (100, 8) + assert data["positions"].dtype == torch.float32 + + +def test_numpy_sample_metadata(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + data, metadata = reader[0] + + assert "index" in metadata + assert metadata["index"] == 0 + assert "source_filename" in metadata + + +def test_numpy_negative_indexing(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + + last_data, _ = reader[-1] + also_last_data, _ = reader[9] + + torch.testing.assert_close(last_data["positions"], also_last_data["positions"]) + + +def test_numpy_index_out_of_range(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + + with pytest.raises(IndexError): + _ = reader[100] + + +def test_numpy_iteration(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + + samples = list(reader) + assert len(samples) == 10 + for i, (data, metadata) in enumerate(samples): + assert metadata["index"] == i + + +def test_numpy_select_fields(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, fields=["positions"]) + data, metadata = reader[0] + + assert "positions" in data + assert "features" not in data + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_numpy_pin_memory(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + data, metadata = reader[0] + + assert data["positions"].is_pinned() + + +def test_numpy_context_manager(numpy_data_dir): + with dp.NumpyReader(numpy_data_dir) as reader: + data, metadata = reader[0] + assert "positions" in data + + +# ============================================================================ +# NumpyReader - Single file mode +# ============================================================================ + + +def test_numpy_load_npz(numpy_npz_file): + reader = dp.NumpyReader(numpy_npz_file) + + assert len(reader) == 15 + assert "images" in reader.field_names + assert "labels" in reader.field_names + + +def test_numpy_get_sample_npz(numpy_npz_file): + reader = dp.NumpyReader(numpy_npz_file) + data, metadata = reader[5] + + assert data["images"].shape == (32, 32) + assert data["labels"].item() == 5 + + +# ============================================================================ +# HDF5Reader +# ============================================================================ + + +def test_hdf5_load_from_directory(hdf5_data_dir): + reader = dp.HDF5Reader(hdf5_data_dir, file_pattern="sample_*.h5") + + assert len(reader) == 10 + assert "mesh" in reader.field_names + assert "pressure" in reader.field_names + + +def test_hdf5_get_sample(hdf5_data_dir): + reader = dp.HDF5Reader(hdf5_data_dir) + data, metadata = reader[0] + + assert data["mesh"].shape == (200, 3) + assert data["pressure"].shape == 
(200,) + assert data["velocity"].shape == (200, 3) + + +def test_hdf5_single_file_mode(hdf5_single_file): + reader = dp.HDF5Reader(hdf5_single_file) + + assert len(reader) == 25 + assert "inputs" in reader.field_names + + +def test_hdf5_get_sample_single_file(hdf5_single_file): + reader = dp.HDF5Reader(hdf5_single_file) + data, metadata = reader[10] + + assert data["inputs"].shape == (64,) + assert data["targets"].shape == (10,) + + +def test_hdf5_select_fields(hdf5_data_dir): + reader = dp.HDF5Reader(hdf5_data_dir, fields=["mesh"]) + data, metadata = reader[0] + + assert "mesh" in data + assert "pressure" not in data + + +def test_hdf5_close(hdf5_single_file): + reader = dp.HDF5Reader(hdf5_single_file) + _ = reader[0] + reader.close() + # Should not raise on close + + +# ============================================================================ +# ZarrReader +# ============================================================================ + + +def test_zarr_load_from_directory(zarr_data_dir): + reader = dp.ZarrReader(zarr_data_dir, group_pattern="sample_*.zarr") + + assert len(reader) == 10 + assert "field_a" in reader.field_names + assert "field_b" in reader.field_names + + +def test_zarr_get_sample(zarr_data_dir): + reader = dp.ZarrReader(zarr_data_dir) + data, metadata = reader[0] + + assert data["field_a"].shape == (50, 50) + assert data["field_b"].shape == (50,) + + +def test_zarr_single_group_mode(zarr_single_group): + reader = dp.ZarrReader(zarr_single_group) + + assert len(reader) == 30 + + +def test_zarr_get_sample_single_group(zarr_single_group): + reader = dp.ZarrReader(zarr_single_group) + data, metadata = reader[5] + + assert data["data"].shape == (16, 16) + assert data["mask"].shape == (16, 16) + + +# ============================================================================ +# Reader errors +# ============================================================================ + + +def test_numpy_empty_directory(temp_dir): + with pytest.raises(ValueError, match="No files matching"): + dp.NumpyReader(temp_dir, file_pattern="*.npz") + + +def test_numpy_unsupported_extension(temp_dir): + # Create a file with wrong extension + (temp_dir / "data.txt").write_text("hello") + + with pytest.raises(ValueError, match="Unsupported file type"): + dp.NumpyReader(temp_dir / "data.txt") + + +# ============================================================================ +# Reader repr +# ============================================================================ + + +def test_numpy_reader_repr(numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + repr_str = repr(reader) + + assert "NumpyReader" in repr_str + assert "len=10" in repr_str + + +def test_hdf5_reader_repr(hdf5_data_dir): + reader = dp.HDF5Reader(hdf5_data_dir) + repr_str = repr(reader) + + assert "HDF5Reader" in repr_str + assert "directory" in repr_str diff --git a/test/datapipes/core/test_transforms.py b/test/datapipes/core/test_transforms.py new file mode 100644 index 0000000000..487b8e0ace --- /dev/null +++ b/test/datapipes/core/test_transforms.py @@ -0,0 +1,839 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for transforms.""" + +import tempfile +import warnings +from pathlib import Path + +import numpy as np +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes.core as dp + +# ============================================================================ +# Normalize transform +# ============================================================================ + + +def test_normalize_single_field(): + sample = TensorDict({"x": torch.tensor([10.0, 20.0, 30.0])}) + norm = dp.Normalize( + input_keys=["x"], + method="mean_std", + means={"x": 20.0}, + stds={"x": 10.0}, + ) + + result = norm(sample) + expected = torch.tensor([-1.0, 0.0, 1.0]) + torch.testing.assert_close(result["x"], expected, atol=1e-6, rtol=1e-6) + + +def test_normalize_multiple_fields(): + sample = TensorDict( + { + "a": torch.tensor([100.0]), + "b": torch.tensor([50.0]), + } + ) + norm = dp.Normalize( + input_keys=["a", "b"], + method="mean_std", + means={"a": 100.0, "b": 0.0}, + stds={"a": 10.0, "b": 50.0}, + ) + + result = norm(sample) + torch.testing.assert_close(result["a"], torch.tensor([0.0]), atol=1e-6, rtol=1e-6) + torch.testing.assert_close(result["b"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6) + + +def test_normalize_preserves_other_fields(): + sample = TensorDict( + { + "x": torch.tensor([10.0]), + "y": torch.tensor([999.0]), + } + ) + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ) + + result = norm(sample) + assert "y" in result + torch.testing.assert_close(result["y"], torch.tensor([999.0])) + + +def test_normalize_inverse(): + sample = TensorDict({"x": torch.tensor([1.0, 2.0, 3.0])}) + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 10.0}, stds={"x": 2.0} + ) + + normalized = norm(sample) + denormalized = norm.inverse(normalized) + + torch.testing.assert_close(denormalized["x"], sample["x"], atol=1e-5, rtol=1e-5) + + +def test_normalize_scalar_mean_std(): + sample = TensorDict( + { + "a": torch.tensor([10.0]), + "b": torch.tensor([20.0]), + } + ) + norm = dp.Normalize(input_keys=["a", "b"], method="mean_std", means=0.0, stds=10.0) + + result = norm(sample) + torch.testing.assert_close(result["a"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6) + torch.testing.assert_close(result["b"], torch.tensor([2.0]), atol=1e-6, rtol=1e-6) + + +def test_normalize_missing_field_raises(): + sample = TensorDict({"x": torch.randn(10)}) + norm = dp.Normalize( + input_keys=["y"], method="mean_std", means={"y": 0.0}, stds={"y": 1.0} + ) + + with pytest.raises(KeyError): + norm(sample) + + +def test_normalize_empty_fields_raises(): + with pytest.raises(ValueError, match="cannot be empty"): + dp.Normalize(input_keys=[], method="mean_std", means={}, stds={}) + + +def test_normalize_missing_mean_raises(): + with pytest.raises(ValueError, match="Mean not provided"): + dp.Normalize( + input_keys=["x", "y"], + method="mean_std", + means={"x": 0.0}, + stds={"x": 1.0, "y": 1.0}, + ) + + +def test_normalize_state_dict(): + norm = dp.Normalize( + input_keys=["x"], + method="mean_std", + means={"x": 5.0}, + stds={"x": 
2.0}, + ) + state = norm.state_dict() + + assert state["input_keys"] == ["x"] + assert state["method"] == "mean_std" + assert "x" in state["means"] + assert "x" in state["stds"] + + +# ============================================================================ +# Min-Max Scaling Tests +# ============================================================================ + + +def test_normalize_minmax_single_field(): + """Test min-max normalization normalizes to [-1, 1].""" + sample = TensorDict({"x": torch.tensor([0.0, 50.0, 100.0])}) + norm = dp.Normalize( + input_keys=["x"], + method="min_max", + mins={"x": 0.0}, + maxs={"x": 100.0}, + ) + + result = norm(sample) + # min=0, max=100 -> center=50, half_range=50 + # Values: (0-50)/50=-1, (50-50)/50=0, (100-50)/50=1 + expected = torch.tensor([-1.0, 0.0, 1.0]) + torch.testing.assert_close(result["x"], expected, atol=1e-6, rtol=1e-6) + + +def test_normalize_minmax_multiple_fields(): + """Test min-max normalization with multiple fields.""" + sample = TensorDict( + { + "pressure": torch.tensor([100000.0]), + "velocity": torch.tensor([0.0]), + } + ) + norm = dp.Normalize( + input_keys=["pressure", "velocity"], + method="min_max", + mins={"pressure": 90000.0, "velocity": -50.0}, + maxs={"pressure": 110000.0, "velocity": 50.0}, + ) + + result = norm(sample) + # pressure: center=100000, half_range=10000 -> (100000-100000)/10000 = 0 + # velocity: center=0, half_range=50 -> (0-0)/50 = 0 + torch.testing.assert_close( + result["pressure"], torch.tensor([0.0]), atol=1e-6, rtol=1e-6 + ) + torch.testing.assert_close( + result["velocity"], torch.tensor([0.0]), atol=1e-6, rtol=1e-6 + ) + + +def test_normalize_minmax_inverse(): + """Test inverse min-max normalization.""" + sample = TensorDict({"x": torch.tensor([25.0, 50.0, 75.0])}) + norm = dp.Normalize( + input_keys=["x"], + method="min_max", + mins={"x": 0.0}, + maxs={"x": 100.0}, + ) + + normalized = norm(sample) + denormalized = norm.inverse(normalized) + + torch.testing.assert_close(denormalized["x"], sample["x"], atol=1e-5, rtol=1e-5) + + +def test_normalize_minmax_scalar_values(): + """Test min-max with scalar min/max applied to all fields.""" + sample = TensorDict( + { + "a": torch.tensor([0.0]), + "b": torch.tensor([100.0]), + } + ) + norm = dp.Normalize(input_keys=["a", "b"], method="min_max", mins=0.0, maxs=100.0) + + result = norm(sample) + # Both use center=50, half_range=50 + torch.testing.assert_close(result["a"], torch.tensor([-1.0]), atol=1e-6, rtol=1e-6) + torch.testing.assert_close(result["b"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6) + + +def test_normalize_minmax_edge_case_same_min_max(): + """Test min-max when min == max (should use eps to avoid division by zero).""" + sample = TensorDict({"x": torch.tensor([50.0])}) + norm = dp.Normalize( + input_keys=["x"], + method="min_max", + mins={"x": 50.0}, + maxs={"x": 50.0}, + eps=1e-8, + ) + + result = norm(sample) + # center=50, half_range=0 -> (50-50)/(0+eps) ≈ 0 + torch.testing.assert_close(result["x"], torch.tensor([0.0]), atol=1e-5, rtol=1e-5) + + +def test_normalize_minmax_state_dict(): + """Test state_dict for min-max normalization.""" + norm = dp.Normalize( + input_keys=["x"], + method="min_max", + mins={"x": 0.0}, + maxs={"x": 100.0}, + ) + state = norm.state_dict() + + assert state["input_keys"] == ["x"] + assert state["method"] == "min_max" + assert "x" in state["mins"] + assert "x" in state["maxs"] + assert "means" not in state + assert "stds" not in state + + +def test_normalize_minmax_load_state_dict(): + """Test loading state_dict 
for min-max normalization.""" + state = { + "input_keys": ["x"], + "method": "min_max", + "mins": {"x": torch.tensor(0.0)}, + "maxs": {"x": torch.tensor(100.0)}, + "eps": 1e-8, + } + + norm = dp.Normalize( + input_keys=["x"], method="min_max", mins={"x": 50.0}, maxs={"x": 150.0} + ) + norm.load_state_dict(state) + + sample = TensorDict({"x": torch.tensor([50.0])}) + result = norm(sample) + # Should use loaded mins/maxs: center=50, half_range=50 + expected = torch.tensor([0.0]) + torch.testing.assert_close(result["x"], expected, atol=1e-6, rtol=1e-6) + + +# ============================================================================ +# File Loading Tests +# ============================================================================ + + +def test_normalize_load_from_npz_mean_std(): + """Test loading mean_std normalization from .npz file.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test npz file + npz_path = Path(tmpdir) / "stats.npz" + stats_data = { + "pressure": {"mean": np.array(100000.0), "std": np.array(10000.0)}, + "velocity": {"mean": np.array(0.0), "std": np.array(10.0)}, + } + np.savez(npz_path, **stats_data) + + # Load normalizer + norm = dp.Normalize( + input_keys=["pressure", "velocity"], + method="mean_std", + stats_file=npz_path, + ) + + sample = TensorDict( + { + "pressure": torch.tensor([110000.0]), + "velocity": torch.tensor([10.0]), + } + ) + result = norm(sample) + + # pressure: (110000 - 100000) / 10000 = 1.0 + # velocity: (10 - 0) / 10 = 1.0 + torch.testing.assert_close( + result["pressure"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6 + ) + torch.testing.assert_close( + result["velocity"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6 + ) + + +def test_normalize_load_from_npz_min_max(): + """Test loading min_max normalization from .npz file.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test npz file + npz_path = Path(tmpdir) / "stats.npz" + stats_data = { + "x": {"min": np.array(0.0), "max": np.array(100.0)}, + "y": {"min": np.array(-50.0), "max": np.array(50.0)}, + } + np.savez(npz_path, **stats_data) + + # Load normalizer + norm = dp.Normalize( + input_keys=["x", "y"], + method="min_max", + stats_file=npz_path, + ) + + sample = TensorDict( + { + "x": torch.tensor([50.0]), + "y": torch.tensor([0.0]), + } + ) + result = norm(sample) + + # x: center=50, half_range=50 -> (50-50)/50 = 0 + # y: center=0, half_range=50 -> (0-0)/50 = 0 + torch.testing.assert_close( + result["x"], torch.tensor([0.0]), atol=1e-6, rtol=1e-6 + ) + torch.testing.assert_close( + result["y"], torch.tensor([0.0]), atol=1e-6, rtol=1e-6 + ) + + +def test_normalize_load_file_not_found(): + """Test error handling when stats file doesn't exist.""" + with pytest.raises(FileNotFoundError, match="not found"): + dp.Normalize( + input_keys=["x"], + method="mean_std", + stats_file="nonexistent_file.npz", + ) + + +def test_normalize_load_missing_field_in_file(): + """Test error handling when required field is missing in stats file.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create npz file with only one field + npz_path = Path(tmpdir) / "stats.npz" + stats_data = { + "x": {"mean": np.array(0.0), "std": np.array(1.0)}, + } + np.savez(npz_path, **stats_data) + + # Try to load normalizer expecting two fields + with pytest.raises(ValueError, match="not found in stats file"): + dp.Normalize( + input_keys=["x", "y"], + method="mean_std", + stats_file=npz_path, + ) + + +def test_normalize_file_override_with_direct_params(): + """Test that direct parameters override file 
parameters.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test npz file with some values + npz_path = Path(tmpdir) / "stats.npz" + stats_data = { + "x": {"mean": np.array(100.0), "std": np.array(10.0)}, + } + np.savez(npz_path, **stats_data) + + # Load with direct override of mean + norm = dp.Normalize( + input_keys=["x"], + method="mean_std", + stats_file=npz_path, + means={"x": 50.0}, # Override mean from file + ) + + sample = TensorDict({"x": torch.tensor([60.0])}) + result = norm(sample) + + # Should use direct mean (50), but file std (10) + # (60 - 50) / 10 = 1.0 + torch.testing.assert_close( + result["x"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6 + ) + + +# ============================================================================ +# Backward Compatibility Tests +# ============================================================================ + + +def test_normalize_backward_compat_deprecated_warning(): + """Test that using means/stds without method raises deprecation warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + norm = dp.Normalize( + input_keys=["x"], + means={"x": 0.0}, + stds={"x": 1.0}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "deprecated" in str(w[-1].message).lower() + + # Should still work correctly + sample = TensorDict({"x": torch.tensor([10.0])}) + result = norm(sample) + torch.testing.assert_close(result["x"], torch.tensor([10.0]), atol=1e-6, rtol=1e-6) + + +def test_normalize_backward_compat_state_dict(): + """Test loading old state_dict without method field.""" + old_state = { + "input_keys": ["x"], + "means": {"x": torch.tensor(5.0)}, + "stds": {"x": torch.tensor(2.0)}, + "eps": 1e-8, + } + + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ) + norm.load_state_dict(old_state) + + # Should default to mean_std method + assert norm.method == "mean_std" + + sample = TensorDict({"x": torch.tensor([7.0])}) + result = norm(sample) + # (7 - 5) / 2 = 1.0 + torch.testing.assert_close(result["x"], torch.tensor([1.0]), atol=1e-6, rtol=1e-6) + + +# ============================================================================ +# Method Validation Tests +# ============================================================================ + + +def test_normalize_invalid_method_raises(): + """Test that invalid method raises ValueError.""" + with pytest.raises(ValueError, match="must be 'mean_std' or 'min_max'"): + dp.Normalize( + input_keys=["x"], + method="invalid_method", + means={"x": 0.0}, + stds={"x": 1.0}, + ) + + +def test_normalize_minmax_missing_mins_raises(): + """Test that min_max method without mins raises ValueError.""" + with pytest.raises(ValueError, match="'mins' and 'maxs' must be provided"): + dp.Normalize( + input_keys=["x"], + method="min_max", + maxs={"x": 100.0}, + ) + + +def test_normalize_minmax_missing_maxs_raises(): + """Test that min_max method without maxs raises ValueError.""" + with pytest.raises(ValueError, match="'mins' and 'maxs' must be provided"): + dp.Normalize( + input_keys=["x"], + method="min_max", + mins={"x": 0.0}, + ) + + +def test_normalize_mean_std_missing_stds_raises(): + """Test that mean_std method without stds raises ValueError.""" + with pytest.raises(ValueError, match="'means' and 'stds' must be provided"): + dp.Normalize( + input_keys=["x"], + method="mean_std", + means={"x": 0.0}, + ) + + +# ============================================================================ +# 
Downsample transform - NOT YET IMPLEMENTED +# ============================================================================ + + +def test_subsample_basic(): + sample = TensorDict({"points": torch.randn(1000, 3)}) + ds = dp.SubsamplePoints(input_keys=["points"], n_points=100) + + result = ds(sample) + assert result["points"].shape == (100, 3) + + +def test_subsample_multiple_fields(): + sample = TensorDict( + { + "positions": torch.randn(500, 3), + "features": torch.randn(500, 8), + } + ) + ds = dp.SubsamplePoints(input_keys=["positions", "features"], n_points=100) + + result = ds(sample) + assert result["positions"].shape == (100, 3) + assert result["features"].shape == (100, 8) + + +def test_subsample_preserves_other_fields(): + sample = TensorDict( + { + "points": torch.randn(500, 3), + "label": torch.tensor([1]), + } + ) + ds = dp.SubsamplePoints(input_keys=["points"], n_points=100) + + result = ds(sample) + assert result["points"].shape == (100, 3) + torch.testing.assert_close(result["label"], torch.tensor([1])) + + +def test_subsample_no_op_when_smaller(): + sample = TensorDict({"x": torch.randn(50, 3)}) + ds = dp.SubsamplePoints(input_keys=["x"], n_points=100) + + result = ds(sample) + # Should return original since 50 < 100 + assert result["x"].shape == (50, 3) + + +def test_subsample_inconsistent_sizes_raises(): + sample = TensorDict( + { + "a": torch.randn(100, 3), + "b": torch.randn(200, 3), # Different size! + } + ) + ds = dp.SubsamplePoints(input_keys=["a", "b"], n_points=50) + + with pytest.raises(ValueError, match="same first dimension"): + ds(sample) + + +def test_subsample_missing_key_raises(): + sample = TensorDict({"x": torch.randn(100, 3)}) + ds = dp.SubsamplePoints(input_keys=["y"], n_points=50) + + with pytest.raises(KeyError): + ds(sample) + + +def test_subsample_weighted(): + sample = TensorDict( + { + "points": torch.randn(1000, 3), + "weights": torch.rand(1000), + } + ) + ds = dp.SubsamplePoints(input_keys=["points"], n_points=100, weights_key="weights") + + result = ds(sample) + assert result["points"].shape == (100, 3) + + +def test_subsample_weighted_missing_weights_raises(): + sample = TensorDict({"points": torch.randn(1000, 3)}) + ds = dp.SubsamplePoints( + input_keys=["points"], n_points=100, weights_key="missing_weights" + ) + + with pytest.raises(KeyError, match="missing_weights"): + ds(sample) + + +def test_subsample_poisson_algorithm(): + sample = TensorDict({"points": torch.randn(1000, 3)}) + ds = dp.SubsamplePoints( + input_keys=["points"], n_points=100, algorithm="poisson_fixed" + ) + + result = ds(sample) + assert result["points"].shape == (100, 3) + + +def test_subsample_uniform_algorithm(): + sample = TensorDict({"points": torch.randn(1000, 3)}) + ds = dp.SubsamplePoints(input_keys=["points"], n_points=100, algorithm="uniform") + + result = ds(sample) + assert result["points"].shape == (100, 3) + + +def test_subsample_repr(): + ds = dp.SubsamplePoints(input_keys=["x"], n_points=100) + assert "SubsamplePoints" in repr(ds) + assert "100" in repr(ds) + + +def test_subsample_repr_with_weights(): + ds = dp.SubsamplePoints(input_keys=["x"], n_points=100, weights_key="areas") + assert "SubsamplePoints" in repr(ds) + assert "weights_key=areas" in repr(ds) + + +# TODO: Implement Downsample transform +# def test_downsample_basic(): +# sample = Sample({"points": torch.randn(1000, 3)}) +# ds = dp.Downsample(input_keys=["points"], n=100) +# +# result = ds(sample) +# assert result["points"].shape == (100, 3) +# +# +# def test_downsample_multiple_fields(): +# 
sample = Sample( +# { +# "positions": torch.randn(500, 3), +# "features": torch.randn(500, 8), +# } +# ) +# ds = dp.Downsample(input_keys=["positions", "features"], n=100) +# +# result = ds(sample) +# assert result["positions"].shape == (100, 3) +# assert result["features"].shape == (100, 8) +# +# +# def test_downsample_preserves_other_fields(): +# sample = Sample( +# { +# "points": torch.randn(500, 3), +# "label": torch.tensor([1]), +# } +# ) +# ds = dp.Downsample(input_keys=["points"], n=100) +# +# result = ds(sample) +# assert result["points"].shape == (100, 3) +# torch.testing.assert_close(result["label"], torch.tensor([1])) +# +# +# def test_downsample_seed_reproducibility(): +# sample = Sample({"x": torch.randn(1000)}) +# ds1 = dp.Downsample(input_keys=["x"], n=100, seed=42) +# ds2 = dp.Downsample(input_keys=["x"], n=100, seed=42) +# +# result1 = ds1(sample) +# result2 = ds2(sample) +# +# torch.testing.assert_close(result1["x"], result2["x"]) +# +# +# def test_downsample_no_op_when_smaller(): +# sample = Sample({"x": torch.randn(50, 3)}) +# ds = dp.Downsample(input_keys=["x"], n=100, replacement=False) +# +# result = ds(sample) +# # Should return original since 50 < 100 and no replacement +# assert result["x"].shape == (50, 3) +# +# +# def test_downsample_with_replacement(): +# sample = Sample({"x": torch.randn(50, 3)}) +# ds = dp.Downsample(input_keys=["x"], n=100, replacement=True) +# +# result = ds(sample) +# # With replacement, can upsample +# assert result["x"].shape == (100, 3) +# +# +# def test_downsample_different_axis(): +# # Shape: (3, 1000) - downsample along axis 1 +# sample = Sample({"x": torch.randn(3, 1000)}) +# ds = dp.Downsample(input_keys=["x"], n=100, axis=1) +# +# result = ds(sample) +# assert result["x"].shape == (3, 100) +# +# +# def test_downsample_inconsistent_sizes_raises(): +# sample = Sample( +# { +# "a": torch.randn(100, 3), +# "b": torch.randn(200, 3), # Different size! 
+# } +# ) +# ds = dp.Downsample(input_keys=["a", "b"], n=50) +# +# with pytest.raises(ValueError, match="has size"): +# ds(sample) +# +# +# def test_downsample_empty_fields_raises(): +# with pytest.raises(ValueError, match="cannot be empty"): +# dp.Downsample(input_keys=[], n=100) +# +# +# def test_downsample_invalid_n_raises(): +# with pytest.raises(ValueError, match="must be >= 1"): +# dp.Downsample(input_keys=["x"], n=0) + + +# ============================================================================ +# Compose transform +# ============================================================================ + + +def test_compose_single_transform(): + sample = TensorDict({"x": torch.tensor([10.0])}) + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 10.0}, stds={"x": 1.0} + ) + pipeline = dp.Compose([norm]) + + result = pipeline(sample) + torch.testing.assert_close(result["x"], torch.tensor([0.0]), atol=1e-6, rtol=1e-6) + + +def test_compose_order_matters(): + sample = TensorDict({"x": torch.tensor([100.0, 200.0, 300.0])}) + + # Normalize then check values + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 200.0}, stds={"x": 100.0} + ) + pipeline = dp.Compose([norm]) + + result = pipeline(sample) + expected = torch.tensor([-1.0, 0.0, 1.0]) + torch.testing.assert_close(result["x"], expected, atol=1e-6, rtol=1e-6) + + +def test_compose_len(): + pipeline = dp.Compose( + [ + dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ), + # dp.Downsample(input_keys=["x"], n=10), # Not implemented yet + ] + ) + assert len(pipeline) == 1 + + +def test_compose_getitem(): + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ) + # ds = dp.Downsample(input_keys=["x"], n=10) # Not implemented yet + pipeline = dp.Compose([norm]) + + assert pipeline[0] is norm + # assert pipeline[1] is ds + + +def test_compose_iteration(): + transforms = [ + dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ), + # dp.Downsample(input_keys=["x"], n=10), # Not implemented yet + ] + pipeline = dp.Compose(transforms) + + for i, t in enumerate(pipeline): + assert t is transforms[i] + + +def test_compose_empty_raises(): + with pytest.raises(ValueError, match="cannot be empty"): + dp.Compose([]) + + +def test_compose_non_transform_raises(): + with pytest.raises(TypeError, match="must be Transform"): + dp.Compose([lambda x: x]) + + +# ============================================================================ +# Transform repr +# ============================================================================ + + +def test_normalize_repr(): + norm = dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ) + assert "Normalize" in repr(norm) + assert "mean_std" in repr(norm) + + +# def test_downsample_repr(): +# ds = dp.Downsample(input_keys=["x"], n=100) +# assert "Downsample" in repr(ds) +# assert "100" in repr(ds) + + +def test_compose_repr(): + pipeline = dp.Compose( + [ + dp.Normalize( + input_keys=["x"], method="mean_std", means={"x": 0.0}, stds={"x": 1.0} + ), + ] + ) + assert "Compose" in repr(pipeline) + assert "Normalize" in repr(pipeline) diff --git a/test/datapipes/core/transforms/test_field_slice.py b/test/datapipes/core/transforms/test_field_slice.py new file mode 100644 index 0000000000..f476849596 --- /dev/null +++ b/test/datapipes/core/transforms/test_field_slice.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 
- 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for FieldSlice transform.""" + +import pytest +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.transforms import FieldSlice + + +def test_index_selection_last_dim(): + """Test selecting specific indices from the last dimension.""" + n_points = 100 + n_features = 10 + selected_indices = [0, 2, 5] + + # Create sample + features = torch.randn(n_points, n_features) + sample = TensorDict({"features": features, "coords": torch.randn(n_points, 3)}) + + # Create transform + transform = FieldSlice({"features": {-1: selected_indices}}) + + # Apply + result = transform(sample) + + # Check shape + assert result["features"].shape == (n_points, len(selected_indices)) + + # Check values + expected = features[:, selected_indices] + assert torch.allclose(result["features"], expected) + + # Unchanged field + assert result["coords"].shape == (n_points, 3) + + +def test_index_selection_first_dim(): + """Test selecting specific indices from the first dimension.""" + data = torch.randn(10, 8, 6) + sample = TensorDict({"data": data}) + + transform = FieldSlice({"data": {0: [1, 3, 5]}}) + result = transform(sample) + + assert result["data"].shape == (3, 8, 6) + expected = data[[1, 3, 5], :, :] + assert torch.allclose(result["data"], expected) + + +def test_slice_selection(): + """Test selecting a slice (start:stop:step).""" + data = torch.randn(100, 10) + sample = TensorDict({"data": data}) + + # Select first 5 elements of last dimension + transform = FieldSlice({"data": {-1: {"start": 0, "stop": 5}}}) + result = transform(sample) + + assert result["data"].shape == (100, 5) + expected = data[:, 0:5] + assert torch.allclose(result["data"], expected) + + +def test_slice_with_step(): + """Test slice with step.""" + data = torch.randn(100, 10) + sample = TensorDict({"data": data}) + + # Select every other element: [0, 2, 4, 6, 8] + transform = FieldSlice({"data": {-1: {"start": 0, "stop": 10, "step": 2}}}) + result = transform(sample) + + assert result["data"].shape == (100, 5) + expected = data[:, 0:10:2] + assert torch.allclose(result["data"], expected) + + +def test_multiple_dimensions(): + """Test slicing multiple dimensions of a single field.""" + data = torch.randn(10, 8, 6) + sample = TensorDict({"data": data}) + + # Slice dim 0 (indices 1, 3) and dim 2 (first 3) + transform = FieldSlice( + { + "data": { + 0: [1, 3], + 2: {"stop": 3}, + } + } + ) + result = transform(sample) + + assert result["data"].shape == (2, 8, 3) + expected = data[[1, 3], :, :][:, :, :3] + assert torch.allclose(result["data"], expected) + + +def test_multiple_fields(): + """Test slicing multiple fields.""" + features = torch.randn(100, 10) + velocity = torch.randn(100, 3) + sample = TensorDict({"features": features, "velocity": velocity}) + + transform = FieldSlice( + { + "features": {-1: [0, 2, 5]}, + "velocity": {-1: [0, 
1]}, # Keep only x, y + } + ) + result = transform(sample) + + assert result["features"].shape == (100, 3) + assert result["velocity"].shape == (100, 2) + + +def test_string_keys_for_hydra(): + """Test that string dimension keys work (for Hydra YAML).""" + data = torch.randn(100, 10) + sample = TensorDict({"data": data}) + + # Use string key "-1" like Hydra would pass + transform = FieldSlice({"data": {"-1": [0, 2, 5]}}) + result = transform(sample) + + assert result["data"].shape == (100, 3) + + +def test_missing_field_raises(): + """Test that missing field raises KeyError.""" + sample = TensorDict({"data": torch.randn(10, 10)}) + + transform = FieldSlice({"missing": {-1: [0]}}) + + with pytest.raises(KeyError, match="missing"): + transform(sample) diff --git a/test/datapipes/core/transforms/test_subsample.py b/test/datapipes/core/transforms/test_subsample.py new file mode 100644 index 0000000000..6e72ef3712 --- /dev/null +++ b/test/datapipes/core/transforms/test_subsample.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Tests for subsampling transforms.""" + +import pytest +import torch +from tensordict import TensorDict + +from physicsnemo.datapipes.core.transforms import ( + SubsamplePoints, + poisson_sample_indices_fixed, +) + + +def test_poisson_sample_indices(): + """Test Poisson sampling indices generation.""" + N = 10000 + k = 1000 + + indices = poisson_sample_indices_fixed(N, k) + + assert indices.shape == (k,) + assert indices.min() >= 0 + assert indices.max() < N + assert indices.dtype == torch.long + + +def test_poisson_sample_large_array(): + """Test Poisson sampling with very large arrays.""" + N = 100_000_000 # 100M points + k = 10000 + + indices = poisson_sample_indices_fixed(N, k) + + assert indices.shape == (k,) + assert indices.min() >= 0 + assert indices.max() < N + + +def test_subsample_points_basic(): + """Test basic point subsampling.""" + transform = SubsamplePoints( + input_keys=["coords", "fields"], + n_points=100, + algorithm="uniform", + ) + + sample = TensorDict( + { + "coords": torch.randn(1000, 3), + "fields": torch.randn(1000, 4), + } + ) + + result = transform(sample) + + assert result["coords"].shape == (100, 3) + assert result["fields"].shape == (100, 4) + + +def test_subsample_points_coordinated(): + """Test that same indices are applied to all input_keys.""" + transform = SubsamplePoints( + input_keys=["coords", "fields"], + n_points=100, + algorithm="uniform", + ) + + # Create data where indices can be verified + coords = torch.arange(1000).unsqueeze(-1).expand(-1, 3).float() + fields = torch.arange(1000).unsqueeze(-1).expand(-1, 4).float() + + sample = TensorDict( + { + "coords": coords, + "fields": fields, + } + ) + + result = transform(sample) + + # First column of coords and fields should match + assert torch.allclose(result["coords"][:, 0], 
result["fields"][:, 0]) + + +def test_subsample_points_skip_small(): + """Test that subsampling is skipped if already small enough.""" + transform = SubsamplePoints( + input_keys=["coords"], + n_points=1000, + ) + + coords = torch.randn(500, 3) + sample = TensorDict({"coords": coords}) + + result = transform(sample) + + # Should return original data unchanged + assert torch.equal(result["coords"], coords) + + +def test_subsample_points_weighted(): + """Test weighted sampling with weights_key parameter.""" + transform = SubsamplePoints( + input_keys=["surface_coords", "surface_fields"], + n_points=100, + algorithm="uniform", + weights_key="surface_areas", + ) + + # Create sample with + # areas (larger areas should be sampled more) + sample = TensorDict( + { + "surface_coords": torch.randn(1000, 3), + "surface_fields": torch.randn(1000, 2), + "surface_areas": torch.rand(1000), + } + ) + + result = transform(sample) + + assert result["surface_coords"].shape == (100, 3) + assert result["surface_fields"].shape == (100, 2) + + +def test_subsample_missing_weights_key(): + """Test that error is raised if weights key is missing.""" + transform = SubsamplePoints( + input_keys=["surface_coords"], + n_points=100, + algorithm="uniform", + weights_key="surface_areas", + ) + + sample = TensorDict( + { + "surface_coords": torch.randn(1000, 3), + # Missing surface_areas + } + ) + + with pytest.raises(KeyError, match="Weights key"): + transform(sample) + + +def test_subsample_device_preservation(): + """Test that subsampling preserves device.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + transform = SubsamplePoints( + input_keys=["coords"], + n_points=100, + ) + + sample = TensorDict( + { + "coords": torch.randn(1000, 3, device="cuda"), + } + ) + + result = transform(sample) + + assert result["coords"].device.type == "cuda" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/v2.0-MIGRATION-GUIDE.md b/v2.0-MIGRATION-GUIDE.md index ad558e5605..25a0d0fd2f 100644 --- a/v2.0-MIGRATION-GUIDE.md +++ b/v2.0-MIGRATION-GUIDE.md @@ -36,6 +36,8 @@ Several new packages have been introduced for PhysicsNeMo v2.0. At a high level 2. `physicsnemo.domain_parallel` contains the `ShardTensor` object and utilities. 3. [TBD] `physicsnemo.diffusion` 4. [TBD] `physicsnemo.mesh` +5. `physicsnemo.datapipes` brings reusable and generic utilities to standardize +GPU-centric data pipelines for SciML. ## Packaging, Installation, and Dependencies @@ -53,7 +55,13 @@ is still a viable installation method. ## PhysicsNeMo Datapipes -[TBD] +PhysicsNeMo DataPipes is a GPU-first, high-performance data loading +infrastructure for scientific machine learning that uses threading and +asynchronous execution to maximize throughput on large, high-resolution +datasets. It provides a modular architecture of readers, transforms, datasets, +and a dataloader that can be configured via Hydra YAML files for reproducibility, +while maintaining familiar PyTorch-like interfaces and easy extensibility +for custom data formats and preprocessing operations. ## Updating your code