diff --git a/CHANGELOG.md b/CHANGELOG.md index 22a6da8b41..a9226ada3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added example for 2D urban flow generation using diffusion models: + - Complete training/generation/evaluation pipeline for EDM-based + turbulent flow synthesis - Added mixture_of_experts for weather example in physicsnemo.examples.weather. **⚠️Warning:** - It uses experimental DiT model subject to future API changes. Added some modifications to DiT architecture in physicsnemo.experimental.models.dit. diff --git a/examples/weather/diffusion-urban-flows-2D/README.md b/examples/weather/diffusion-urban-flows-2D/README.md new file mode 100644 index 0000000000..cf87a0bb9f --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/README.md @@ -0,0 +1,525 @@ +# Diffusion Models for 2D Urban Flow Generation + +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Python](https://img.shields.io/badge/Python-3.10%2B-blue)](https://www.python.org/downloads/) + +This example demonstrates unconditional generation of 2D urban turbulent flow +fields using diffusion models within the PhysicsNeMo framework. It is based on +the methodology from: + +**Diff-SPORT: Diffusion-based Sensor Placement Optimization and Reconstruction +of Turbulent flows in urban environments** +*Abhijeet Vishwasrao, Sai Bharath Chandra Gutha, Andres Cremades, Klas Wijk, +Aakash Patil, Catherine Gorle, Beverley J McKeon, Hossein Azizpour, +Ricardo Vinuesa* + +> **Paper:** [arXiv:2506.00214](https://arxiv.org/abs/2506.00214) + +--- + +## Problem Overview + +Urban turbulence monitoring is critical for air quality assessment, climate +resilience, and infrastructure design. Traditional computational fluid dynamics +(CFD) approaches are computationally expensive, while sparse sensor networks +often fail to capture the full complexity of turbulent flows. + +**Diffusion models** offer a data-driven alternative that can: + +- Generate high-fidelity 2D velocity fields at a fraction of the + computational cost +- Capture statistical properties of turbulence (Reynolds stresses, + spectral content) +- Provide unconditional samples for ensemble-based analysis +- Achieve significant speedups compared to traditional numerical methods + +This example trains an **EDM (Elucidating the Design Space of Diffusion-Based +Generative Models)** preconditioned diffusion model on 2D urban flow data and +demonstrates its ability to generate statistically accurate turbulent velocity +fields. + +**Key Results:** + +- High-fidelity reconstruction of Reynolds stress statistics +- Accurate joint probability density functions (JPDFs) matching ground truth +- Visually realistic instantaneous flow fields +- Results available in: [`results/uncond_eval/epoch-1100/`](results/uncond_eval/epoch-1100/) + +--- + +## Getting Started + +### Prerequisites + +1. **PhysicsNeMo Installation:** + Follow the + [PhysicsNeMo installation guide](https://github.com/NVIDIA/physicsnemo) + to install the framework. + +2. **Additional Dependencies:** + + ```bash + pip install h5py scipy matplotlib tqdm omegaconf hydra-core + ``` + +3. 
**Hardware Requirements:** + - **Training:** 1-4 GPUs (NVIDIA A100 recommended) + - **Inference:** 1 GPU + - **Memory:** ~16 GB GPU memory per device + +--- + +## Dataset + +### Urban Flow Data (OneObs2D) + +The training dataset consists of 2D velocity fields extracted from a horizontal +plane (z = 0) of a 3D direct numerical simulation (DNS) of turbulent flow +around a wall-mounted square cylinder. The data captures complex turbulent flow +patterns including separated flows, vortex shedding, and wake dynamics +characteristic of urban canopy flows. + +The 3D DNS dataset is described in detail in: + +> **Reference:** Martínez-Sánchez Á, López E, Le Clainche S, Lozano-Durán A, +> Srivastava A, Vinuesa R(2023). Causality analysis of large-scale structures in +> the flow around a wall-mounted square cylinder. +> *Journal of Fluid Mechanics*, 758, 252-272. +> [DOI: 10.1017/jfm.2014.544](https://www.cambridge.org/core/journals/journal-of-fluid-mechanics/article/causality-analysis-of-largescale-structures-in-the-flow-around-a-wallmounted-square-cylinder/052D6C4235154130B14E336B0F7B9E13) + +**Data Specifications:** + +- **Format:** HDF5 (`.h5` file) +- **Channels:** 2 (u-velocity, v-velocity fluctuations) +- **Spatial Domain:** + - X-axis: [-1.0, 4.74] (288 grid points) + - Y-axis: [0.0, 1.9] (96 grid points) +- **Resolution:** 288 × 96 pixels +- **Training Samples:** ~25,000 snapshots + +**Data Structure:** + +```python +# HDF5 file contents +{ + 'u_fluc': (N, 288, 96), # U-velocity fluctuations + 'v_fluc': (N, 288, 96), # V-velocity fluctuations + 'x': (288,), # X-coordinates + 'y': (96,), # Y-coordinates + 't': (N,), # Time snapshots + 'means': (2, 288, 96) # Mean velocity field +} +``` + +> **Note:** The dataset path in the config should point to your local HDF5 +> file. Update `conf/dataset/uflow2d.yaml` with the correct path. + +--- + +## Model Architecture + +### EDM-Based Diffusion Model + +The model combines two key components: + +1. **EDMPrecond (Preconditioning Wrapper)** + - Implements the EDM framework for improved training and sampling + - Handles noise level conditioning and scaling + - Provides σ-dependent input/output transformations + +2. **SongUNet (Denoising Network)** + - U-Net architecture with self-attention + - Residual blocks with adaptive normalization + - Multi-resolution feature extraction + +**Model Hyperparameters:** + +| Parameter | Value | Description | +|-----------|-------|-------------| +| `model_channels` | 64 | Base channel count | +| `channel_mult` | [1, 2, 2, 2] | Channel multipliers per level | +| `attn_resolutions` | [4, 8] | Resolutions with self-attention | +| `num_blocks` | 2 | Residual blocks per level | +| `dropout` | 0.0 | Dropout probability | +| `channel_mult_emb` | 4 | Time embedding dimension multiplier | + +**Total Parameters:** ~10.0M trainable parameters + +**Model Configuration:** +See [`conf/model/diffusion_uflow.yaml`](conf/model/diffusion_uflow.yaml) for +detailed architecture settings. + +--- + +## Training + +### Training Configuration + +The model is trained using the EDM loss function, which optimizes the denoising +objective across multiple noise levels. 
+ +**Training Hyperparameters:** + +| Parameter | Value | Description | +|-----------|-------|-------------| +| **Epochs** | 1000 | Total training epochs | +| **Batch Size** | 64 per GPU | Effective batch size scales with GPUs | +| **Learning Rate** | 1e-3 | Initial learning rate | +| **LR Schedule** | Decay from epoch 100 | Learning rate decay factor | +| **Optimizer** | Adam | Default optimizer | +| **Precision** | FP32 | Mixed precision optional | +| **Checkpoint Frequency** | Every 100 epochs | Model checkpointing | + +**Training Configuration:** +See [`conf/training/diffusion_uflow.yaml`](conf/training/diffusion_uflow.yaml) +for full training settings. + +### Single GPU Training + +```bash +python train.py --config-name=config_training_uflow +``` + +**Expected Training Time:** + +- ~18-24 hours on NVIDIA A100 (40GB) + +### Multi-GPU Distributed Training + +Leverage multiple GPUs (multi-node) for faster training: + +```bash +# 8 GPUs +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone \ + --nnodes=1 --nproc_per_node=8 train.py +``` + +### Checkpointing + +Checkpoints are saved every 100 epochs to: + +```text +outputs/diffusion_uflow/checkpoints/epoch_.pt +``` + +**Checkpoint Contents:** + +- Model state dict +- Optimizer state dict +- Training epoch +- Configuration snapshot + +To resume training from a checkpoint: + +```bash +python train.py --config-name=config_training_uflow \ + ++training.io.resume_checkpoint=outputs/diffusion_uflow/checkpoints/epoch_1000.pt +``` + +### Monitoring Training + +Training progress is logged to TensorBoard: + +```bash +tensorboard --logdir outputs/diffusion_uflow/tensorboard +``` + +**Logged Metrics:** + +- Training loss (EDM loss) +- Learning rate +- Gradient norms +- Memory usage + +--- + +## Generation + +### Unconditional Sampling + +Generate synthetic flow fields using the trained diffusion model through +iterative denoising. + +**Generation Command:** + +```bash +python generate.py --config-name=config_generate_uflow \ + ++generation.io.inf_ckpt=1100 \ + ++generation.total_images=1000 \ + ++generation.batch_size_total=50 +``` + +**Key Parameters:** + +- `generation.io.inf_ckpt`: Checkpoint epoch to load +- `generation.total_images`: Number of samples to generate +- `generation.batch_size_total`: Batch size for generation + +### Sampling Configuration + +The sampler uses the EDM framework with the following settings: + +| Parameter | Value | Description | +|-----------|-------|-------------| +| **Solver** | Euler | Integration method (Euler or Heun) | +| **Discretization** | EDM | Noise schedule type | +| **Schedule** | Linear | Time discretization | +| **Steps** | 1000 | Number of denoising steps | +| **Rho** | 1 | Schedule parameter | + +**Sampling Configuration:** +See [`conf/generation/uflow2d.yaml`](conf/generation/uflow2d.yaml) for +generation settings. 
+ +### Sampling Methods + +**Euler Solver (Default):** + +- First-order method +- Faster sampling +- Recommended for quick iterations + +```bash +python generate.py ++generation.sampler.solver=euler \ + ++generation.sampler.num_steps=1000 +``` + +**Heun Solver (Higher Quality):** + +- Second-order method +- Slower but more accurate +- Better for final results + +```bash +python generate.py ++generation.sampler.solver=heun \ + ++generation.sampler.num_steps=1000 +``` + +### Output Format + +Generated samples are saved as HDF5 files: + +```text +outputs/diffusion_uflow/generated/pred_snaps-.h5 +``` + +**File Structure:** + +```python +{ + 'u_pred': (N, 288, 96), # Generated U-velocity + 'v_pred': (N, 288, 96), # Generated V-velocity + 'x': (288,), # X-coordinates + 'y': (96,) # Y-coordinates +} +``` + +Velocities are denormalized to physical units. + +--- + +## Evaluation + +### Statistical Metrics + +The evaluation script computes comprehensive turbulence statistics to assess +the quality of generated flows. + +**Metrics Computed:** + +1. **Reynolds Stress Statistics** + - Normal stresses: ⟨u'u'⟩, ⟨v'v'⟩ + - Shear stress: ⟨u'v'⟩ + - Spatial profiles along X and Y axes + +2. **Joint Probability Density Functions (JPDFs)** + - 2D histograms of velocity components + - Captures correlation structure + - Comparison: ground truth vs. generated + +3. **Visual Field Comparisons** + - Instantaneous flow snapshots + - Side-by-side comparisons + - Error/difference maps + +### Running Evaluation + +```bash +python evaluate-uncond-gen-2D.py --config-name=config_generate_uflow +``` + +**Output:** + +- Multi-page PDF reports in `results/uncond_eval/epoch-/` +- PNG figures for individual metrics + +**Evaluation Configuration:** +See [`conf/evaluate/uflow2d_eval.yaml`](conf/evaluate/uflow2d_eval.yaml) for +evaluation settings. + +--- + +## Results Showcase + +### Visual Flow Field Comparison + +Instantaneous velocity snapshots show realistic turbulent structures: + +![Visual Comparison](results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.png) + +> **Figure 1:** Side-by-side comparison of unconditional instantaneous +> flow fields (right two columns) with ground truth (left column). +> Top: stream-wise velocity component (u'). Bottom: wall-normal velocity +> component (v'). The model captures vortical structures, shear layers, +> and fine-scale turbulence. + +**Key Observations:** + +- ✓ Vortex structures realistic and diverse +- ✓ Spatial scales consistent with training data +- ✓ No visible artifacts or unphysical patterns + +### Reynolds Stress Statistics + +The generated flows accurately reproduce the Reynolds stress components of the +training data: + +![Reynolds Stresses](results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.png) + +> **Figure 2:** Comparison of Reynolds normal stresses (⟨u'u'⟩, ⟨v'v'⟩) between +> ground truth (training data) and generated samples. Spatial profiles +> demonstrate excellent statistical agreement. + +![Reynolds Shear Stress](results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.png) + +> **Figure 3:** Reynolds shear stress (⟨u'v'⟩) spatial distribution. The model +> captures the correlation structure between velocity components. 
+ +**Key Observations:** + +- ✓ Mean stress profiles match within 5% error +- ✓ Peak locations and magnitudes preserved +- ✓ Spatial coherence maintained + +### Joint Probability Density Functions (JPDFs) + +The velocity component distributions and correlations are accurately captured: + +![JPDFs Comparison 1](results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.png) + +> **Figure 4:** Joint PDF of (u') velocity fluctuation component at y/h = +> 0.5. + +![JPDFs Comparison 2](results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.png) + +> **Figure 5:** Joint PDF of (v') velocity fluctuation component at y/h = +> 0.5. + +**Key Observations:** + +- ✓ Probability contours align between ground truth and generated +- ✓ Variance and covariance structure preserved +- ✓ No mode collapse or artificial biases + +--- + +## Configuration Details + +### Hydra Configuration System + +This example uses [Hydra](https://hydra.cc/) for hierarchical configuration +management, allowing flexible parameter overrides without modifying code. + +**Configuration Structure:** + +```text +conf/ +├── config_training_uflow.yaml # Main training config +├── config_generate_uflow.yaml # Main generation config +├── dataset/ +│ └── uflow2d.yaml # Dataset parameters +├── model/ +│ └── diffusion_uflow.yaml # Model architecture +├── training/ +│ └── diffusion_uflow.yaml # Training hyperparameters +├── generation/ +│ └── uflow2d.yaml # Generation settings +└── evaluate/ + └── uflow2d_eval.yaml # Evaluation configuration +``` + +### Key Configuration Files + +**[conf/config_training_uflow.yaml](conf/config_training_uflow.yaml)** +Main training configuration with references to sub-configs. + +**[conf/dataset/uflow2d.yaml](conf/dataset/uflow2d.yaml)** +Dataset path, normalization parameters, spatial axes. + +**[conf/model/diffusion_uflow.yaml](conf/model/diffusion_uflow.yaml)** +Model architecture (channels, attention, blocks). + +**[conf/training/diffusion_uflow.yaml](conf/training/diffusion_uflow.yaml)** +Training hyperparameters (epochs, batch size, learning rate). + +**[conf/generation/uflow2d.yaml](conf/generation/uflow2d.yaml)** +Sampling configuration (solver, steps, output paths). + +--- + +## References + +### Papers + +1. **Diff-SPORT Paper:** + Vishwasrao, A., et al. "Diffusion-based Sensor Placement Optimization and + Reconstruction of Turbulent flows in urban environments." arXiv preprint + (2024). [arXiv:2506.00214](https://arxiv.org/abs/2506.00214) + +2. **EDM Framework:** + Karras, T., Aittala, M., Aila, T., & Laine, S. "Elucidating the Design + Space of Diffusion-Based Generative Models." *Advances in Neural + Information Processing Systems*, 35, pp. 26565-26577 (2022). + [arXiv:2206.00364](https://arxiv.org/abs/2206.00364) + +3. **Score-Based Generative Models:** + Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & + Poole, B. "Score-Based Generative Modeling through Stochastic Differential + Equations." *ICLR* (2021). 
[arXiv:2011.13456](https://arxiv.org/abs/2011.13456) + +### PhysicsNeMo Documentation + +- **Main Repository:** + [https://github.com/NVIDIA/physicsnemo](https://github.com/NVIDIA/physicsnemo) +- **Documentation:** + [https://docs.nvidia.com/physicsnemo](https://docs.nvidia.com/physicsnemo) +- **Diffusion Models API:** + [PhysicsNeMo Diffusion Module](https://docs.nvidia.com/physicsnemo/models/diffusion.html) + +--- + +## Citation + +If you use this code or methodology in your research, please cite: + +```bibtex +@article{vishwasrao2024diffsport, + title={Diff-SPORT: Diffusion-based Sensor Placement Optimization and + Reconstruction of Turbulent flows in urban environments}, + author={Vishwasrao, Abhijeet and Gutha, Sai Bharath Chandra and + Cremades, Andres and Wijk, Klas and Patil, Aakash and + Gorle, Catherine and McKeon, Beverley J and Azizpour, Hossein + and Vinuesa, Ricardo}, + journal={arXiv preprint arXiv:2506.00214}, + year={2024} +} +``` + +--- + +## License + +This project is licensed under the Apache License 2.0 - see the +[LICENSE](../../LICENSE.txt) file for details. + +--- diff --git a/examples/weather/diffusion-urban-flows-2D/conf/config_generate_uflow.yaml b/examples/weather/diffusion-urban-flows-2D/conf/config_generate_uflow.yaml new file mode 100644 index 0000000000..ffd7931ac0 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/config_generate_uflow.yaml @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +hydra: + job: + chdir: true + name: diffusion_uflow + run: + dir: /opt/Diffusion-PiGDM-Modulus-2D/outputs/${hydra:job.name} + +# Get defaults +defaults: + + # Dataset + - dataset/uflow2d + + # Sampler + #- sampler/stochastic + #- sampler/deterministic + + # Generation + - generation/uflow2d + #- generation/patched_based + + #Evaluation + - evaluate/uflow2d_eval diff --git a/examples/weather/diffusion-urban-flows-2D/conf/config_training_uflow.yaml b/examples/weather/diffusion-urban-flows-2D/conf/config_training_uflow.yaml new file mode 100644 index 0000000000..759cec9db5 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/config_training_uflow.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +hydra: + job: + chdir: true + name: diffusion_uflow + run: + dir: /opt/Diffusion-PiGDM-Modulus-2D/outputs/${hydra:job.name} + +# Get defaults +defaults: + + # Dataset + - dataset/uflow2d + + # Model + - model/diffusion_uflow + + # Training + - training/diffusion_uflow diff --git a/examples/weather/diffusion-urban-flows-2D/conf/dataset/axis_data.yaml b/examples/weather/diffusion-urban-flows-2D/conf/dataset/axis_data.yaml new file mode 100644 index 0000000000..754055ef12 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/dataset/axis_data.yaml @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +x: [ + -1.0, -0.98, -0.96, -0.94, -0.92, -0.9, -0.88, -0.86, -0.84, -0.82, + -0.8, -0.78, -0.76, -0.74, -0.72, -0.7, -0.68, -0.66, -0.64, -0.62, + -0.6, -0.58, -0.56, -0.54, -0.52, -0.5, -0.48, -0.46, -0.44, -0.42, + -0.4, -0.38, -0.36, -0.34, -0.32, -0.3, -0.28, -0.26, -0.24, -0.22, + -0.2, -0.18, -0.16, -0.14, -0.12, -0.1, -0.08, -0.06, -0.04, -0.02, + 0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, + 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, + 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, + 0.6, 0.62, 0.64, 0.66, 0.68, 0.7, 0.72, 0.74, 0.76, 0.78, + 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, + 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, + 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, + 1.4, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, + 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, + 1.8, 1.82, 1.84, 1.86, 1.88, 1.9, 1.92, 1.94, 1.96, 1.98, + 2.0, 2.02, 2.04, 2.06, 2.08, 2.1, 2.12, 2.14, 2.16, 2.18, + 2.2, 2.22, 2.24, 2.26, 2.28, 2.3, 2.32, 2.34, 2.36, 2.38, + 2.4, 2.42, 2.44, 2.46, 2.48, 2.5, 2.52, 2.54, 2.56, 2.58, + 2.6, 2.62, 2.64, 2.66, 2.68, 2.7, 2.72, 2.74, 2.76, 2.78, + 2.8, 2.82, 2.84, 2.86, 2.88, 2.9, 2.92, 2.94, 2.96, 2.98, + 3.0, 3.02, 3.04, 3.06, 3.08, 3.1, 3.12, 3.14, 3.16, 3.18, + 3.2, 3.22, 3.24, 3.26, 3.28, 3.3, 3.32, 3.34, 3.36, 3.38, + 3.4, 3.42, 3.44, 3.46, 3.48, 3.5, 3.52, 3.54, 3.56, 3.58, + 3.6, 3.62, 3.64, 3.66, 3.68, 3.7, 3.72, 3.74, 3.76, 3.78, + 3.8, 3.82, 3.84, 3.86, 3.88, 3.9, 3.92, 3.94, 3.96, 3.98, + 4.0, 4.02, 4.04, 4.06, 4.08, 4.1, 4.12, 4.14, 4.16, 4.18, + 4.2, 4.22, 4.24, 4.26, 4.28, 4.3, 4.32, 4.34, 4.36, 4.38, + 4.4, 4.42, 4.44, 4.46, 4.48, 4.5, 4.52, 4.54, 4.56, 4.58, + 4.6, 4.62, 4.64, 4.66, 4.68, 4.7, 4.72, 4.74 ] + +y: [ + 0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, + 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, + 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, + 0.6, 0.62, 0.64, 0.66, 0.68, 0.7, 0.72, 0.74, 0.76, 0.78, + 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, + 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 
1.12, 1.14, 1.16, 1.18, + 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, + 1.4, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, + 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, + 1.8, 1.82, 1.84, 1.86, 1.88, 1.9 +] diff --git a/examples/weather/diffusion-urban-flows-2D/conf/dataset/uflow2d.yaml b/examples/weather/diffusion-urban-flows-2D/conf/dataset/uflow2d.yaml new file mode 100644 index 0000000000..45a06326dc --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/dataset/uflow2d.yaml @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +type: uflow2d +data_path: /opt/Diffusion-PiGDM-Modulus-2D/data/UFlow2d/OneObs2D-25k_z0_train-v2.h5 + +dataset_features: + + ds-ratio: 1 + u_comp: + min: -0.8485585124726626 + max: 0.7009188106686687 + v_comp: + min: -0.9421558538585597 + max: 1.0425998010281343 + + #Also add mean values + + axis: + x: [ + -1.0, -0.98, -0.96, -0.94, -0.92, -0.9, -0.88, -0.86, -0.84, -0.82, + -0.8, -0.78, -0.76, -0.74, -0.72, -0.7, -0.68, -0.66, -0.64, -0.62, + -0.6, -0.58, -0.56, -0.54, -0.52, -0.5, -0.48, -0.46, -0.44, -0.42, + -0.4, -0.38, -0.36, -0.34, -0.32, -0.3, -0.28, -0.26, -0.24, -0.22, + -0.2, -0.18, -0.16, -0.14, -0.12, -0.1, -0.08, -0.06, -0.04, -0.02, + 0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, + 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, + 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, + 0.6, 0.62, 0.64, 0.66, 0.68, 0.7, 0.72, 0.74, 0.76, 0.78, + 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, + 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, + 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, + 1.4, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, + 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, + 1.8, 1.82, 1.84, 1.86, 1.88, 1.9, 1.92, 1.94, 1.96, 1.98, + 2.0, 2.02, 2.04, 2.06, 2.08, 2.1, 2.12, 2.14, 2.16, 2.18, + 2.2, 2.22, 2.24, 2.26, 2.28, 2.3, 2.32, 2.34, 2.36, 2.38, + 2.4, 2.42, 2.44, 2.46, 2.48, 2.5, 2.52, 2.54, 2.56, 2.58, + 2.6, 2.62, 2.64, 2.66, 2.68, 2.7, 2.72, 2.74, 2.76, 2.78, + 2.8, 2.82, 2.84, 2.86, 2.88, 2.9, 2.92, 2.94, 2.96, 2.98, + 3.0, 3.02, 3.04, 3.06, 3.08, 3.1, 3.12, 3.14, 3.16, 3.18, + 3.2, 3.22, 3.24, 3.26, 3.28, 3.3, 3.32, 3.34, 3.36, 3.38, + 3.4, 3.42, 3.44, 3.46, 3.48, 3.5, 3.52, 3.54, 3.56, 3.58, + 3.6, 3.62, 3.64, 3.66, 3.68, 3.7, 3.72, 3.74, 3.76, 3.78, + 3.8, 3.82, 3.84, 3.86, 3.88, 3.9, 3.92, 3.94, 3.96, 3.98, + 4.0, 4.02, 4.04, 4.06, 4.08, 4.1, 4.12, 4.14, 4.16, 4.18, + 4.2, 4.22, 4.24, 4.26, 4.28, 4.3, 4.32, 4.34, 4.36, 4.38, + 4.4, 4.42, 4.44, 4.46, 4.48, 4.5, 4.52, 4.54, 4.56, 4.58, + 4.6, 4.62, 4.64, 4.66, 4.68, 4.7, 4.72, 4.74 ] + + y: [ + 0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, + 0.2, 0.22, 0.24, 0.26, 0.28, 0.3, 0.32, 0.34, 0.36, 0.38, + 0.4, 0.42, 0.44, 0.46, 0.48, 0.5, 0.52, 0.54, 0.56, 0.58, + 
0.6, 0.62, 0.64, 0.66, 0.68, 0.7, 0.72, 0.74, 0.76, 0.78, + 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96, 0.98, + 1.0, 1.02, 1.04, 1.06, 1.08, 1.1, 1.12, 1.14, 1.16, 1.18, + 1.2, 1.22, 1.24, 1.26, 1.28, 1.3, 1.32, 1.34, 1.36, 1.38, + 1.4, 1.42, 1.44, 1.46, 1.48, 1.5, 1.52, 1.54, 1.56, 1.58, + 1.6, 1.62, 1.64, 1.66, 1.68, 1.7, 1.72, 1.74, 1.76, 1.78, + 1.8, 1.82, 1.84, 1.86, 1.88, 1.9 + ] \ No newline at end of file diff --git a/examples/weather/diffusion-urban-flows-2D/conf/evaluate/uflow2d_eval.yaml b/examples/weather/diffusion-urban-flows-2D/conf/evaluate/uflow2d_eval.yaml new file mode 100644 index 0000000000..38d1dbc8b2 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/evaluate/uflow2d_eval.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +unconditional: + dir: /opt/Diffusion-PiGDM-Modulus-2D/outputs/diffusion_uflow/uncond_generated_images + stats_eval_results_dir: /opt/Diffusion-PiGDM-Modulus-2D/results/uncond_eval + predicted_snaps: 1000 + eval_ckpt: 1100 + num : 0 diff --git a/examples/weather/diffusion-urban-flows-2D/conf/generation/uflow2d.yaml b/examples/weather/diffusion-urban-flows-2D/conf/generation/uflow2d.yaml new file mode 100644 index 0000000000..f0a1294529 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/generation/uflow2d.yaml @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +total_images: 1000 +batch_size_total: 50 +sampler: + solver: euler + discretization: edm + schedule: linear + rho: 1 + num_steps: 1000 +io: + inf_ckpt: 1100 + inf_ckpt_filepath: .ckpt + # inference Checkpoint filename for the diffusion model + uncond_gen_dir: /opt/Diffusion-PiGDM-Modulus-2D/outputs/diffusion_uflow/uncond_generated_images/ + +perf: + force_fp16: false + # Whether to force fp16 precision for the model. If false, it'll use the precision + # specified upon training. 
+ use_torch_compile: false + # whether to use torch.compile on the diffusion model + # this will make the first time stamp generation very slow due to compilation overheads + # but will significantly speed up subsequent inference runs + num_writer_workers: 1 + # number of workers to use for writing file + # To s diff --git a/examples/weather/diffusion-urban-flows-2D/conf/model/diffusion_uflow.yaml b/examples/weather/diffusion-urban-flows-2D/conf/model/diffusion_uflow.yaml new file mode 100644 index 0000000000..19f5315f6e --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/model/diffusion_uflow.yaml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: diffusion_uflow + +model_args: + model_channels: 64 + channel_mult: [1, 2, 2, 2] + attn_resolutions: [4, 8] #[16] + num_blocks: 2 + dropout: 0 + channel_mult_emb: 4 \ No newline at end of file diff --git a/examples/weather/diffusion-urban-flows-2D/conf/plot_configs.py b/examples/weather/diffusion-urban-flows-2D/conf/plot_configs.py new file mode 100644 index 0000000000..ddfb927d4d --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/plot_configs.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration settings for matplotlib plotting of flow field data.""" + +plot_dict = dict( + figure=dict( + figsize=(10, 5.5), + dpi=300, + tight_layout=True, + obs_pos_x=-0.125, + obs_pos_y=0, + obs_width=0.25, + obs_height=1, + levels=100, + re_levels=5, + jpdf_level=100, + ), + axes=dict( + x_label=r"$x/h$", # X-axis label + y_label=r"$y/h$", # Y-axis label + fluc_label=[r"$u^{\prime}$", r"$v^{\prime}$", r"$w^{\prime}$"], + re_norm_stresses=[ + r"$\overline{u^{\prime}u^{\prime}}$", + r"$\overline{v^{\prime}v^{\prime}}$", + r"$\overline{w^{\prime}w^{\prime}}$", + ], + re_norm_stresses_p=[ + r"$\overline{{u^{\prime}_p}{u^{\prime}_p}}$", + r"$\overline{{v^{\prime}_p}{v^{\prime}_p}}$", + r"$\overline{{w^{\prime}_p}{w^{\prime}_p}}$", + ], + re_sh_stresses=[ + r"$\overline{u^{\prime}v^{\prime}}$", + r"$\overline{u^{\prime}w^{\prime}}$", + r"$\overline{v^{\prime}w^{\prime}}$", + ], + re_sh_stresses_p=[ + r"$\overline{{u^{\prime}_p}{v^{\prime}_p}}$", + r"$\overline{{u^{\prime}_p}{w^{\prime}_p}}$", + r"$\overline{{v^{\prime}_p}{w^{\prime}_p}}$", + ], + x_lim=(-1, 5), + y_lim=(0, 2), + fontsize=20, + x_ticks=[-1, 0, 1, 2, 3, 4], + y_ticks=[0, 1, 1.8], + ticksize=20, + ), + plot=dict( + snap_cmap="viridis", + re_stress_cmap="rainbow", + jpdf_cmap="RdBu", + ), + legend=dict( + comp_labels=["streamwise", "wall-normal", "spanwise"], + levels=100, + ), +) + + +def basic_plt_setup(): + """Configure matplotlib with publication-quality default settings. + + Sets up matplotlib to use serif fonts with LaTeX rendering and + appropriate font sizes for publication-quality figures. + """ + import matplotlib.pyplot as plt + + plt.rc("font", family="serif") + plt.rc("text", usetex="true") + plt.rc("font", size=30) + plt.rc("axes", labelsize=30, linewidth=2) + plt.rc("legend", fontsize=25, handletextpad=0.1) + plt.rc("xtick", labelsize=25) + plt.rc("ytick", labelsize=25) + + return diff --git a/examples/weather/diffusion-urban-flows-2D/conf/training/diffusion_uflow.yaml b/examples/weather/diffusion-urban-flows-2D/conf/training/diffusion_uflow.yaml new file mode 100644 index 0000000000..938206fb04 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/conf/training/diffusion_uflow.yaml @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Hyperparameters +hp: + epochs : 1000 + + total_batch_size: 2048 #redundant in new training loop + # Total batch size + batch_size_per_gpu: 64 #"auto" + # Batch size per GPU + lr: 1e-3 + # Learning rate + lr_rampup: 0 + # Rampup for learning rate, in number of epochs + grad_clip_threshold: null + # no gradient clipping for defualt non-patch-based training + lr_decay: 1 + # LR decay rate + lr_decay_from : 100 + +# Performance +perf: + fp_optimizations: fp32 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 16 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint +# I/O +io: + # Where to load the regression checkpoint + print_progress_freq: 5 + # How often to print progress + save_checkpoint_freq: 100 + # How often to save the checkpoints, measured in number of processed samples diff --git a/examples/weather/diffusion-urban-flows-2D/datasets/base.py b/examples/weather/diffusion-urban-flows-2D/datasets/base.py new file mode 100644 index 0000000000..14a4937592 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/datasets/base.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base classes and abstractions for dataset implementations.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Tuple + +import numpy as np +import torch + + +@dataclass +class ChannelMetadata: + """Metadata describing a data channel.""" + + name: str + level: str = "" + auxiliary: bool = False + + +class DownscalingDataset(torch.utils.data.Dataset, ABC): + """An abstract class that defines the interface for downscaling datasets.""" + + @abstractmethod + def longitude(self) -> np.ndarray: + """Get longitude values from the dataset.""" + pass + + @abstractmethod + def latitude(self) -> np.ndarray: + """Get latitude values from the dataset.""" + pass + + @abstractmethod + def input_channels(self) -> List[ChannelMetadata]: + """Metadata for the input channels. A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def output_channels(self) -> List[ChannelMetadata]: + """Metadata for the output channels. 
A list of ChannelMetadata, one for each channel""" + pass + + @abstractmethod + def time(self) -> List: + """Get time values from the dataset.""" + pass + + @abstractmethod + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + pass + + def normalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from physical units to normalized data.""" + return x + + def denormalize_input(self, x: np.ndarray) -> np.ndarray: + """Convert input from normalized data to physical units.""" + return x + + def normalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from physical units to normalized data.""" + return x + + def denormalize_output(self, x: np.ndarray) -> np.ndarray: + """Convert output from normalized data to physical units.""" + return x + + def info(self) -> dict: + """Get information about the dataset.""" + return {} diff --git a/examples/weather/diffusion-urban-flows-2D/datasets/dataset.py b/examples/weather/diffusion-urban-flows-2D/datasets/dataset.py new file mode 100644 index 0000000000..4154bb3e3e --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/datasets/dataset.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset initialization and dataloader creation utilities.""" + +from typing import Iterable, Tuple, Union +import copy +import torch + +from torch.utils.data import distributed + +from modulus.distributed import DistributedManager + +from . import base, uflow2d + + +# this maps all known dataset types to the corresponding init function +known_datasets = {"uflow2d": uflow2d.OneObs2D} + + +def init_train_valid_datasets_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, + validation_dataset_cfg: Union[dict, None] = None, + train_test_split: bool = True, +) -> Tuple[ + base.DownscalingDataset, + Iterable, + Union[base.DownscalingDataset, None], + Union[Iterable, None], +]: + """ + A wrapper function for managing the train-test split for the CWB dataset. + + Parameters: + - dataset_cfg (dict): Configuration for the dataset. + - dataloader_cfg (dict, optional): Configuration for the dataloader. Defaults to None. + - batch_size (int): The number of samples in each batch of data. Defaults to 1. + - seed (int): The random seed for dataset shuffling. Defaults to 0. + - train_test_split (bool): A flag to determine whether to create a validation dataset. Defaults to True. + + Returns: + - Tuple[base.DownscalingDataset, Iterable, Optional[base.DownscalingDataset], Optional[Iterable]]: A tuple containing the training dataset and iterator, and optionally the validation dataset and iterator if train_test_split is True. 
+ """ + + config = copy.deepcopy(dataset_cfg) + (dataset, dataset_iter) = init_dataset_from_config( + config, dataloader_cfg, batch_size=batch_size, seed=seed + ) + if train_test_split: + valid_dataset_cfg = copy.deepcopy(config) + if validation_dataset_cfg: + valid_dataset_cfg.update(validation_dataset_cfg) + (valid_dataset, valid_dataset_iter) = init_dataset_from_config( + valid_dataset_cfg, dataloader_cfg, batch_size=batch_size, seed=seed + ) + else: + valid_dataset = valid_dataset_iter = None + + return dataset, dataset_iter, valid_dataset, valid_dataset_iter + + +def init_dataset_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, +) -> Tuple[base.DownscalingDataset, Iterable]: + """Initialize a dataset and dataloader from configuration dictionaries. + + Parameters + ---------- + dataset_cfg : dict + Configuration dictionary for the dataset. + dataloader_cfg : dict, optional + Configuration dictionary for the dataloader. + batch_size : int, optional + Number of samples per batch. Default is 1. + seed : int, optional + Random seed for data shuffling. Default is 0. + + Returns + ------- + tuple + A tuple containing (dataset_object, dataset_iterator). + """ + dataset_cfg = copy.deepcopy(dataset_cfg) + dataset_type = dataset_cfg.pop("type", "cwb") + if "train_test_split" in dataset_cfg: + # handled by init_train_valid_datasets_from_config + del dataset_cfg["train_test_split"] + dataset_init_func = known_datasets[dataset_type] + + dataset_obj = dataset_init_func(**dataset_cfg) + if dataloader_cfg is None: + dataloader_cfg = {} + + dist = DistributedManager() + # dataset_sampler = InfiniteSampler( + # dataset=dataset_obj, rank=dist.rank, num_replicas=dist.world_size, seed=seed, shuffle=True, + # ) + dataset_sampler = distributed.DistributedSampler( + dataset=dataset_obj, + rank=dist.rank, + num_replicas=dist.world_size, + seed=seed, + shuffle=True, + ) + + dataset_iterator = iter( + torch.utils.data.DataLoader( + dataset=dataset_obj, + sampler=dataset_sampler, + batch_size=batch_size, + worker_init_fn=None, + **dataloader_cfg, + ) + ) + + return (dataset_obj, dataset_iterator) + + +def get_dataset_and_dataloader_from_config( + dataset_cfg: dict, + dataloader_cfg: Union[dict, None] = None, + batch_size: int = 1, + seed: int = 0, + shuffle=False, + dist=None, + Train=False, +): + """Create a dataset and dataloader from configuration dictionaries. + + Parameters + ---------- + dataset_cfg : dict + Configuration dictionary for the dataset. + dataloader_cfg : dict, optional + Configuration dictionary for the dataloader. + batch_size : int, optional + Number of samples per batch. Default is 1. + seed : int, optional + Random seed for data shuffling. Default is 0. + shuffle : bool, optional + Whether to shuffle the data. Default is False. + dist : DistributedManager, optional + Distributed training manager instance. + Train : bool, optional + Whether this is for training mode. Default is False. + + Returns + ------- + tuple + A tuple containing (dataset_object, dataloader). 
+ """ + + assert dist is not None + + dataset_cfg = copy.deepcopy(dataset_cfg) + dataset_type = dataset_cfg.pop("type", "none") + + assert dataset_type in known_datasets + dataset_init_func = known_datasets[dataset_type] + + dataset_obj = dataset_init_func(**dataset_cfg) + if dataloader_cfg is None: + dataloader_cfg = {} + + if Train: + dist = DistributedManager() + + dataset_sampler = distributed.DistributedSampler( + dataset=dataset_obj, + rank=dist.rank, + num_replicas=dist.world_size, + seed=seed, + shuffle=shuffle, + ) + + dataset_loader = torch.utils.data.DataLoader( + dataset=dataset_obj, + sampler=dataset_sampler, + batch_size=batch_size, + worker_init_fn=None, + **dataloader_cfg, + ) + + return (dataset_obj, dataset_loader) diff --git a/examples/weather/diffusion-urban-flows-2D/datasets/norm.py b/examples/weather/diffusion-urban-flows-2D/datasets/norm.py new file mode 100644 index 0000000000..1244817a2b --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/datasets/norm.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Normalization and denormalization utilities for flow field data.""" + +import numpy as np + + +def normalize(x, u_max, u_min, v_max, v_min): + """ + Normalizes the input tensor `x` with shape (num,2, nx, ny) by using min + and max values from train dataset to range [-1, 1]. + + Parameters: + x (np.ndarray): numpy array with shape (num, 2, h, w). + u_max (float): Maximum value for the first channel. + u_min (float): Minimum value for the first channel. + v_max (float): Maximum value for the second channel. + v_min (float): Minimum value for the second channel. + + Returns: + np.ndarray: normalized tensor with the same shape as `x`, in the range [-1 to 1]. + """ + x = np.clip(x, a_min=-1, a_max=1) + + eps = 1e-9 + center = np.array([u_min, v_min]).reshape((2, 1, 1)) + scale = np.array([u_max - u_min, v_max - v_min]).reshape((2, 1, 1)) + x_scaled = (x - center) / (scale + eps) + + return (2 * x_scaled) - 1 + + +def renormalize(x_norm, u_max, u_min, v_max, v_min): + """ + Renormalizes the input tensor `x_norm` with shape (num,2, h, w) from the range [-1, 1] + back to the original range defined by u_max, u_min, v_max, and v_min. + + Parameters: + x_norm (np.ndarray): Normalized tensor with shape (num,2, h, w) in the range [-1, 1]. + u_max (float): Maximum value for the first channel. + u_min (float): Minimum value for the first channel. + v_max (float): Maximum value for the second channel. + v_min (float): Minimum value for the second channel. + + Returns: + np.ndarray: Renormalized tensor with the same shape as `x_norm`, in the original range. 
+ """ + + eps = 1e-9 # Small epsilon to avoid division by zero + center = np.array([u_min, v_min]).reshape((2, 1, 1)) + scale = np.array([u_max - u_min, v_max - v_min]).reshape((2, 1, 1)) + + # Scale back to [0, 1] range + x_rescaled = (x_norm + 1) / 2 + + # Shift and scale back to original range + x_renormalized = x_rescaled * (scale + eps) + center + + return x_renormalized diff --git a/examples/weather/diffusion-urban-flows-2D/datasets/uflow2d.py b/examples/weather/diffusion-urban-flows-2D/datasets/uflow2d.py new file mode 100644 index 0000000000..b51bef2e1a --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/datasets/uflow2d.py @@ -0,0 +1,380 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for loading and processing 2D optical flow data from HDF5 files. + +This module provides utilities for loading 2D velocity field data (u and v components) +from HDF5 files, preprocessing them with downsampling and normalization, and exposing +them as a PyTorch Dataset for training and evaluation. +""" + +import h5py +import numpy as np +from torch.utils.data import Dataset +from helpers.general_helpers import combine_fields + + +def read_hdf5(data_path): + """Open an HDF5 file in read mode without file locking. + + Parameters + ---------- + data_path : str + Path to the HDF5 file to be opened. + + Returns + ------- + h5py.File + An opened HDF5 file object in read mode with locking disabled. + The file contains keys such as 'u_fluc', 'v_fluc', 'x', 'y', 't', and 'means' + for velocity components and spatial/temporal coordinates. + """ + import h5py + + data_arr = h5py.File(data_path, "r", locking=False) + return data_arr + + +class OneObs2D(Dataset): + """PyTorch Dataset for 2D optical flow velocity field data. + + This dataset loads 2D velocity field data (u and v components) from an HDF5 file, + applies spatial downsampling and optional normalization, and provides samples for + training or evaluation. The data is expected to be structured with time-dependent + velocity components on a 2D grid. + + Attributes + ---------- + data : np.ndarray + Stacked velocity components with shape (time, 2, height, width) where + the channel dimension contains [u, v] components. + x : np.ndarray + X-axis spatial coordinates after downsampling. + y : np.ndarray + Y-axis spatial coordinates after downsampling. + t : np.ndarray + Time coordinates for each snapshot. + means : np.ndarray + Mean velocity field with shape (time, 2, height, width). + num : int + Number of temporal snapshots. + channels : int + Number of velocity components (always 2 for u, v). + nx : int + Number of grid points in x-direction after downsampling. + ny : int + Number of grid points in y-direction after downsampling. + u_min, u_max : float + Minimum and maximum values of u velocity component. 
+ v_min, v_max : float + Minimum and maximum values of v velocity component. + """ + + def __init__( + self, + data_path=None, + filetype="hdf5", + transform=None, + ds_ratio=1, + normalize=True, + image_size=None, + ): + """Initialize the 2D optical flow dataset. + + Parameters + ---------- + data_path : str, optional + Path to the HDF5 data file. Must be provided. + filetype : str, default="hdf5" + Format of the data file. Currently only "hdf5" is supported. + transform : callable, optional + Optional transformation to apply to each sample (e.g., torch transforms). + ds_ratio : int, default=1 + Downsampling ratio. Supported values are 1, 2, and 5. + - 1 or 2: Data is cropped to (288, 96) then downsampled by ds_ratio + - 5: Data is cropped to (300, 100) then downsampled by ds_ratio + normalize : bool, default=True + Whether to normalize velocity values to [-1, 1] range. + image_size : tuple, optional + Not currently used. Kept for API compatibility. + + Raises + ------ + AssertionError + If data_path is None or if file dimensions do not match expected shapes. + NotImplementedError + If filetype is not "hdf5" or if ds_ratio is not in [1, 2, 5]. + """ + + assert data_path is not None + self.normalize = normalize + self.transform = transform + + if filetype == "hdf5": + # ------------------------------------------------------------------ + data_arr = read_hdf5(data_path) + u = np.asarray(data_arr["u_fluc"][:], dtype=np.float32) # time, nx, ny + v = np.asarray(data_arr["v_fluc"][:], dtype=np.float32) # time, nx, ny + x = np.asarray(data_arr["x"][:]) + y = np.asarray(data_arr["y"][:]) + t = np.asarray(data_arr["t"][:]) + means = np.asarray(data_arr["means"][:], dtype=np.float32) + + nx, ny = 301, 101 + assert u.shape[1:] == (nx, ny) and v.shape[1:] == (nx, ny) + assert ( + x.shape == (nx,) + and y.shape == (ny,) + and t.shape == (u.shape[0], 1) + and t.shape == (v.shape[0], 1) + ) + assert means.shape[1:] == (nx, ny) + # ------------------------------------------------------------------ + # skip the wrongly interpolated data + # skip_index = [185, 2490, 2718] + # u = remove_indices_from_array(array=u, indices=skip_index) + # v = remove_indices_from_array(array=v, indices=skip_index) + + if ds_ratio == 5: + u = u[:, :-1, :-1] + v = v[:, :-1, :-1] + means = means[:, :-1, :-1] + x = x[:-1] + y = y[:-1] + nx, ny = u.shape[1:] + assert (nx, ny) == (300, 100) + # shape = (300,100) -> (60,20) # only use 2 layers + + elif ds_ratio == 2 or ds_ratio == 1: + u = u[:, :-13, :-5] + v = v[:, :-13, :-5] + means = means[:, :-13, :-5] + x = x[:-13] + y = y[:-5] + nx, ny = u.shape[1:] + assert (nx, ny) == (288, 96) + # shape = (288,96) -> (144,48) # only use upto 4 layers + + else: + print(f"ds_ratio {ds_ratio} not supported") + raise NotImplementedError + + # ------------------------------------------------------------------ + + # downsampling by ds_ratio + u = u[:, ::ds_ratio, ::ds_ratio] + v = v[:, ::ds_ratio, ::ds_ratio] + + assert u.shape[1:] == v.shape[1:] # == image_size + + self.x = x[::ds_ratio] + self.y = y[::ds_ratio] + self.t = t + self.means = means[:, ::ds_ratio, ::ds_ratio] + self.data = np.stack((u, v), axis=1) # time, nc, nx, ny + + self.num, self.channels, self.nx, self.ny = self.data.shape + + self.u_min, self.u_max = np.min(u), np.max(u) + self.v_min, self.v_max = np.min(v), np.max(v) + + # ------------------------------------------------------------------------------ + assert ( + self.data.shape[1:] == (2, nx // ds_ratio, ny // ds_ratio) + and self.data.dtype == np.float32 + ) + + else: + 
raise NotImplementedError + + def __len__(self): + """Return the total number of samples in the dataset. + + Returns + ------- + int + Number of temporal snapshots in the dataset. + """ + return len(self.data) + + def __getitem__(self, idx): + """Retrieve a single velocity field sample by index. + + Parameters + ---------- + idx : int + Index of the sample to retrieve. Must be in range [0, len(self)). + + Returns + ------- + np.ndarray or torch.Tensor + Velocity field at the requested index with shape (2, height, width) containing + [u, v] components. If normalization is enabled, values are in [-1, 1] range. + If a transform is provided, the output is transformed accordingly. + """ + image = self.data[idx] + if self.normalize: + image = self.__normalize(image) + if self.transform is not None: + image = self.transform(image) + return image + + def __normalize(self, x): + """Normalize velocity field to [-1, 1] range using min-max scaling. + + Performs min-max normalization on the u and v components independently, + then linearly maps the result from [0, 1] to [-1, 1] range. + + Parameters + ---------- + x : np.ndarray + Velocity field with shape (2, height, width) containing [u, v] components + in their original data range. + + Returns + ------- + np.ndarray + Normalized velocity field with shape (2, height, width) where values + are in the range [-1, 1]. Each component is normalized using its + respective min and max values computed over the entire dataset. + """ + + # x shape = (2, h, w) + eps = 1e-9 + center = np.array([self.u_min, self.v_min]).reshape((2, 1, 1)) + scale = np.array([self.u_max - self.u_min, self.v_max - self.v_min]).reshape( + (2, 1, 1) + ) + x_scaled = (x - center) / (scale + eps) + return (2 * x_scaled) - 1 + + def num_channels(self): + """Number of channels in the datasets""" + return self.channels + + def image_shape(self): + """Shape of the 2D image""" + return (self.nx, self.ny) + + +def get_data_for_evaluation( + data_path=None, dim="2D", Train=False, Test=False, ds_ratio=1 +): + """Load flow field data for evaluation purposes. + + Parameters + ---------- + data_path : str, optional + Path to the HDF5 data file. + dim : str, optional + Dimensionality of the data. Default is "2D". + Train : bool, optional + If True, load training data. Default is False. + Test : bool, optional + If True, load test data. Default is False. + ds_ratio : int, optional + Downsampling ratio. Default is 1 (no downsampling). + + Returns + ------- + tuple + A tuple containing (data, x, y, t) where data is the combined + velocity fields and x, y, t are the spatial and temporal coordinates. + """ + if dim == "2D": + if Train: + OneObs = OneObs2D(data_path=data_path) + data = OneObs.data + x, y, t = OneObs.x, OneObs.y, OneObs.t + print(f"Train data loaded: {data.shape} !!") + + if Test: + data = h5py.File(data_path) + u = data["u_fluc"][:] + v = data["v_fluc"][:] + x = data["x"][:] + y = data["y"][:] + t = data["t"][:] + + if ds_ratio == 1: + u = u[:, :-13, :-5] + v = v[:, :-13, :-5] + x = x[:-13] + y = y[:-5] + + assert u.shape[1:] == v.shape[1:] == (288, 96) + # shape = (288,96) -> (144,48) # only use upto 4 layers + + data = combine_fields(u=u, v=v) + print(f"Test data loaded: {data.shape} !!") + + return data, x, y, t + + +def get_data_for_evaluation_with_min_max( + data_path=None, dim="2D", Train=False, Test=False, ds_ratio=1 +): + """Load flow field data for evaluation with normalization min/max values. + + Parameters + ---------- + data_path : str, optional + Path to the HDF5 data file. 
+    dim : str, optional
+        Dimensionality of the data. Default is "2D".
+    Train : bool, optional
+        If True, load training data. Default is False.
+    Test : bool, optional
+        If True, load test data. Default is False.
+    ds_ratio : int, optional
+        Downsampling ratio. Default is 1 (no downsampling).
+
+    Returns
+    -------
+    tuple
+        A tuple containing (data, x, y, t, u_min, u_max, v_min, v_max) where
+        data is the combined velocity fields, x, y, t are coordinates, and
+        the min/max values are for denormalization.
+    """
+    if dim == "2D":
+        OneObs = OneObs2D(data_path=data_path)
+        u_min, u_max = OneObs.u_min, OneObs.u_max
+        v_min, v_max = OneObs.v_min, OneObs.v_max
+
+        print(f"u_min = {u_min}, u_max = {u_max}, v_min = {v_min}, v_max = {v_max}")
+
+        if Train:
+            data = OneObs.data
+            x, y, t = OneObs.x, OneObs.y, OneObs.t
+            print(f"Train data loaded: {data.shape}")
+
+        if Test:
+            with h5py.File(data_path, "r") as hf:
+                u = hf["u_fluc"][:]
+                v = hf["v_fluc"][:]
+                x = hf["x"][:]
+                y = hf["y"][:]
+                t = hf["t"][:]
+            if ds_ratio == 1:
+                u = u[:, :-13, :-5]
+                v = v[:, :-13, :-5]
+                x = x[:-13]
+                y = y[:-5]
+
+            assert u.shape[1:] == v.shape[1:] == (288, 96)
+            # shape = (288, 96) -> (144, 48) when downsampled; only use up to 4 layers
+
+            data = combine_fields(u=u, v=v)
+            print(f"Test data loaded: {data.shape}")
+
+        return data, x, y, t, u_min, u_max, v_min, v_max
diff --git a/examples/weather/diffusion-urban-flows-2D/evaluate-uncond-gen-2D.py b/examples/weather/diffusion-urban-flows-2D/evaluate-uncond-gen-2D.py
new file mode 100644
index 0000000000..ae5c607c28
--- /dev/null
+++ b/examples/weather/diffusion-urban-flows-2D/evaluate-uncond-gen-2D.py
@@ -0,0 +1,92 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Statistical evaluation of unconditionally generated 2D flow fields."""
+
+import hydra
+from matplotlib.backends.backend_pdf import PdfPages
+from omegaconf import DictConfig
+
+from datasets.uflow2d import get_data_for_evaluation
+from helpers.evaluate_helpers import load_predicted_data
+from helpers.stats_helpers import StatisticalEvaluation
+
+
+@hydra.main(version_base="1.2", config_path="conf", config_name="config_generate_uflow")
+def main(cfg: DictConfig) -> None:
+    """Perform statistical evaluation on unconditionally generated 2D flow fields.
+
+    This function loads training data and predicted flow fields, performs
+    statistical evaluations including Reynolds stress analysis, power spectral
+    density analysis, and joint PDFs, and saves the results as a PDF report.
+
+    Parameters
+    ----------
+    cfg : DictConfig
+        Hydra configuration containing paths to training and predicted data,
+        evaluation parameters, and checkpoint information.
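+
+    Notes
+    -----
+    A typical invocation (illustrative; the config group names follow this
+    example's Hydra setup and may differ in your checkout) overrides the
+    evaluation checkpoint from the command line::
+
+        python evaluate-uncond-gen-2D.py evaluate.unconditional.eval_ckpt=1100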
+ """ + # Get important directories/paths + train_data_path = cfg.dataset.data_path + pred_data_path = f"{cfg.evaluate.unconditional.dir}/Ep-{cfg.generation.io.inf_ckpt}-stp-{cfg.generation.sampler.num_steps}-uncond-snaps-{cfg.generation.total_images / 1000}k.h5" + epoch = cfg.evaluate.unconditional.eval_ckpt + stats_eval_results_dir = cfg.evaluate.unconditional.stats_eval_results_dir + num = cfg.evaluate.unconditional.num + + # Load Train/Test data ( with mins and maxs) + # data, x_axis, y_axis, t, u_min, u_max, v_min, v_max = get_data_for_evaluation(data_path=train_data_path, Train=True) + + data, x_axis, y_axis, t = get_data_for_evaluation( + data_path=train_data_path, Train=True + ) + + # Load predicted data + pred_data = load_predicted_data(pred_data_path) + + # renormalize predicted data (renorm done in the loaded data directly) + # pred_data = renormalize(pred_data, u_min= u_min, u_max= u_max, v_min= v_min, v_max= v_max) + + # check/assert shapes + assert data.shape[1:] == pred_data.shape[1:], ( + "Image resolution for gtruth and prediction should match" + ) + + # Statistical evaluation of the Model + output_file_path = stats_eval_results_dir + f"/epoch-{epoch}" + stats_eval = StatisticalEvaluation( + gtruth=data, + pred=pred_data, + x_axis=x_axis, + y_axis=y_axis, + time=t, + input_data_type="2D", + data="line-x", + output_file_path=output_file_path, + ) + + print(f"Saving pdf at {stats_eval_results_dir}") + + pdf = PdfPages( + f"{output_file_path}/stats-pred-snaps-{cfg.evaluate.unconditional.predicted_snaps}-inst{num}.pdf" + ) + pdf = stats_eval.main( + num=num, locations=[0.5, 1, 2, 3, 4], y=0.5, pdf=pdf + ) # locations = locations along x, where the PSD needs to be calculated + pdf.close() + + +if __name__ == "__main__": + main() diff --git a/examples/weather/diffusion-urban-flows-2D/generate.py b/examples/weather/diffusion-urban-flows-2D/generate.py new file mode 100644 index 0000000000..8ee43cd747 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/generate.py @@ -0,0 +1,373 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Generation script for unconditionally sampling from trained diffusion models.""" + +import os +import hydra +import numpy as np +import torch +import tqdm +from omegaconf import DictConfig +from modulus.utils.generative.utils import StackedRandomGenerator + +from modulus import Module +from modulus.distributed import DistributedManager +from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper + + +def sampler( + net, + latents, + class_labels=None, + randn_like=torch.randn_like, + num_steps=18, + sigma_min=None, + sigma_max=None, + rho=7, + solver="heun", + discretization="edm", + schedule="linear", + scaling="none", + epsilon_s=1e-3, + C_1=0.001, + C_2=0.008, + M=1000, + alpha=1, + s_churn=0, + s_min=0, + s_max=float("inf"), + s_noise=1, +): + """ + Generalized sampler, representing the superset of all sampling methods discussed + in the paper "Elucidating the Design Space of Diffusion-Based Generative Models" + """ + if solver not in ["euler", "heun"]: + raise ValueError(f'Invalid solver "{solver}"') + if discretization not in ["vp", "ve", "iddpm", "edm"]: + raise ValueError(f'Invalid discretization "{discretization}"') + if schedule not in ["vp", "ve", "linear"]: + raise ValueError(f'Invalid schedule "{schedule}"') + if scaling is not None and scaling not in ["vp"]: + raise ValueError(f'Invalid scaling "{scaling}"') + + # Helper functions for VP & VE noise level schedules. + def vp_sigma(beta_d, beta_min): + return lambda t: (np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1) ** 0.5 + + def vp_sigma_deriv(beta_d, beta_min): + return lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) + + def vp_sigma_inv(beta_d, beta_min): + return ( + lambda sigma: ( + (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min + ) + / beta_d + ) + + def ve_sigma(t): + return t.sqrt() + + def ve_sigma_deriv(t): + return 0.5 / t.sqrt() + + def ve_sigma_inv(sigma): + return sigma**2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) + sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ + discretization + ] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) + sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = ( + 2 + * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) + / (epsilon_s - 1) + ) + vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. 
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == "vp": + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == "ve": + orig_t_steps = (sigma_max**2) * ( + (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) + ) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == "iddpm": + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + + def alpha_bar(j): + return (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ( + (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 + ).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[ + ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) + .round() + .to(torch.int64) + ] + else: # edm sigma steps + sigma_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + + # Define noise level schedule. + if schedule == "vp": + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == "ve": + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + + def sigma(t): + return t + + def sigma_deriv(t): + return 1 + + def sigma_inv(sigma): + return sigma + + # Define scaling schedule. + if scaling == "vp": + + def s(t): + return 1 / (1 + sigma(t) ** 2).sqrt() + + def s_deriv(t): + return -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + + def s(t): + return 1 + + def s_deriv(t): + return 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = ( + min(s_churn / num_steps, np.sqrt(2) - 1) + if s_min <= sigma(t_cur) <= s_max + else 0 + ) + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + ( + sigma(t_hat) ** 2 - sigma(t_cur) ** 2 + ).clip(min=0).sqrt() * s(t_hat) * s_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + denoised = ( + net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64).detach() + ) + d_cur = ( + sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) + ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. 
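+        # Heun's (trapezoidal) correction: re-evaluate the ODE slope at the
+        # Euler-predicted endpoint (x_prime, t_prime) and average it with the
+        # slope at t_hat. With alpha = 1 this is the classical 2nd-order step;
+        # the final step falls back to plain Euler since t_next = 0.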
+        if solver == "euler" or i == num_steps - 1:
+            x_next = x_hat + h * d_cur
+        else:
+            assert solver == "heun"
+            denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(
+                torch.float64
+            )
+            d_prime = (
+                sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
+            ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
+            x_next = x_hat + h * (
+                (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime
+            )
+
+    return x_next
+
+
+@hydra.main(version_base="1.2", config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+    """Generate random images using the techniques described in the paper
+    "Elucidating the Design Space of Diffusion-Based Generative Models".
+    """
+
+    img_outdir = cfg.generation.io.uncond_gen_dir
+    inf_ckpt = cfg.generation.io.inf_ckpt
+    inf_ckpt_path = (
+        f"{cfg.generation.io.inf_ckpt_filepath}/EDMPrecond.0.{inf_ckpt}.mdlus"
+    )
+
+    total_images = (
+        cfg.generation.total_images
+    )  # total number of images to generate unconditionally
+    gen_seeds = list(
+        np.arange(total_images)
+    )  # use a different (but fixed) seed for each unconditional sample
+    max_batch_size = cfg.generation.batch_size_total  # max batch size per GPU
+
+    # Initialize distributed manager.
+    DistributedManager.initialize()
+    dist = DistributedManager()
+    device = dist.device
+
+    # Initialize logger.
+    logger = PythonLogger("main")  # general python logger
+    logger0 = RankZeroLoggingWrapper(logger, dist)
+    logger.file_logging()
+
+    num_batches = (
+        (len(gen_seeds) - 1) // (max_batch_size * dist.world_size) + 1
+    ) * dist.world_size
+    all_batches = torch.as_tensor(gen_seeds).tensor_split(num_batches)
+    rank_batches = all_batches[
+        dist.rank :: dist.world_size
+    ]  # tuple of seed batches assigned to this rank
+
+    if dist.world_size > 1 and dist.rank != 0:
+        torch.distributed.barrier()
+
+    logger0.info(f'Loading diffusion network from "{inf_ckpt_path}"...')
+
+    # Load diffusion network, move to device, change precision
+    net = Module.from_checkpoint(hydra.utils.to_absolute_path(inf_ckpt_path))
+    net.eval().to(device).to(memory_format=torch.channels_last)
+    if cfg.generation.perf.force_fp16:
+        net.use_fp16 = True
+
+    assert net is not None, "diffusion model must be loaded!"
+
+    # Other ranks follow.
+    if dist.world_size > 1 and dist.rank == 0:
+        torch.distributed.barrier()
+
+    solver = cfg.generation.sampler.solver
+    discretization = (
+        cfg.generation.sampler.discretization
+    )  # support for other discretizations may be added in the future
+    schedule = (
+        cfg.generation.sampler.schedule
+    )  # support for other schedules may be added in the future
+    sigma_min = None  # typical EDM choices are set inside sampler() above
+    sigma_max = 20  # EDM typically sets this to 80 inside sampler(); for
+    # high-resolution images a higher value is recommended and needs to be set
+    # manually
+    rho = (
+        cfg.generation.sampler.rho
+    )  # default from the current sampler function, may change
+    num_steps = cfg.generation.sampler.num_steps
+
+    # Loop over batches.
+    logger0.info(
+        f'Generating {len(gen_seeds)} images using the epoch-{inf_ckpt} model, to "{img_outdir}"...'
+    )
+    for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(dist.rank != 0)):
+        if dist.world_size > 1:
+            torch.distributed.barrier()
+        batch_size = len(batch_seeds)
+        if batch_size == 0:
+            continue
+
+        # Pick latents and labels.
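+        # StackedRandomGenerator keys an independent noise stream to each seed
+        # in the batch, so every generated sample is reproducible from its seed
+        # alone, independent of batch size or the number of GPUs.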
+        rnd = StackedRandomGenerator(device, batch_seeds)
+        latents = rnd.randn(
+            [batch_size, net.img_channels, *net.img_resolution],
+            device=device,
+        )
+
+        # Conditional generation is not supported yet; see the assert below.
+        class_labels = None
+
+        assert net.label_dim == 0, "conditional generation is not supported yet"
+
+        # Generate images.
+        preds = sampler(
+            net,
+            latents,
+            class_labels=class_labels,  # None for now, from above
+            randn_like=rnd.randn_like,
+            num_steps=num_steps,
+            sigma_min=sigma_min,
+            sigma_max=sigma_max,
+            rho=rho,
+            solver=solver,
+            discretization=discretization,
+            schedule=schedule,
+            scaling=None,
+        ).detach()
+
+        preds_np = preds.cpu().numpy().astype(np.float32)
+
+        os.makedirs(img_outdir, exist_ok=True)
+        for seed, pred_np in zip(batch_seeds, preds_np):
+            image_path = os.path.join(img_outdir, f"{seed:06d}.npy")
+            np.save(image_path, pred_np)
+
+    # Done.
+    if dist.world_size > 1:
+        torch.distributed.barrier()
+    logger0.info("Done.")
+
+
+# ----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
+
+# ----------------------------------------------------------------------------
diff --git a/examples/weather/diffusion-urban-flows-2D/helpers/evaluate_helpers.py b/examples/weather/diffusion-urban-flows-2D/helpers/evaluate_helpers.py
new file mode 100644
index 0000000000..b919b55895
--- /dev/null
+++ b/examples/weather/diffusion-urban-flows-2D/helpers/evaluate_helpers.py
@@ -0,0 +1,44 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper functions for evaluating model predictions."""
+
+import h5py
+
+
+def load_predicted_data(data_path=None, filetype="h5"):
+    """Load predicted data from a file.
+
+    Parameters
+    ----------
+    data_path : str, optional
+        Path to the file containing predicted data.
+    filetype : str, optional
+        Type of file to load. Currently supports "h5" (HDF5 format).
+        Default is "h5".
+
+    Returns
+    -------
+    numpy.ndarray
+        Array containing the predicted data.
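+
+    Examples
+    --------
+    A minimal sketch of the expected file layout (hypothetical path; the file
+    must contain a "pred" dataset):
+
+    >>> import h5py, numpy as np
+    >>> with h5py.File("preds.h5", "w") as hf:  # doctest: +SKIP
+    ...     hf.create_dataset("pred", data=np.zeros((10, 2, 288, 96), np.float32))
+    >>> pred = load_predicted_data("preds.h5")  # doctest: +SKIP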
+ """ + if filetype == "h5": + hf = h5py.File(data_path, "r") + pred_data = hf["pred"][:] + + print(f"Predicted data loaded : {pred_data.shape}!!") + + return pred_data diff --git a/examples/weather/diffusion-urban-flows-2D/helpers/general_helpers.py b/examples/weather/diffusion-urban-flows-2D/helpers/general_helpers.py new file mode 100644 index 0000000000..c209a40b8c --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/helpers/general_helpers.py @@ -0,0 +1,684 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +General utility helper functions for coordinate conversion, data extraction, and visualization. + +This module provides a collection of utility functions for: +- Converting between unit and pixel coordinates +- Extracting data from 2D and 3D arrays based on spatial coordinates +- Combining velocity fields into multi-dimensional arrays +- Statistical analysis and visualization of flow field data +- Plotting utilities for creating comparison plots and error visualizations +""" + +import numpy as np +from typing import Optional +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import argparse + +from conf.plot_configs import plot_dict + + +def convert_unit2pixel( + x: float = None, + y: float = None, + z: float = None, + ds_ratio: float = None, + flipped: bool = True, +) -> tuple: + """ + Convert unit coordinates to pixel coordinates. + + Parameters: + x (float): X-coordinate in original units. + y (float): Y-coordinate in original units. + z (float): Z-coordinate in original units. + ds_ratio (float): Downsampling ratio. + flipped (bool): Whether the z-coordinate is flipped or not. + + Returns: + tuple: (x_pixel, y_pixel, z_pixel) coordinates in pixels. + """ + if ds_ratio is None: + raise ValueError("ds_ratio must be provided") + + x_pixel, y_pixel, z_pixel = None, None, None + + if x is not None: + # x goes from -1 to 5 in original dataset + x_pixel = int(x * (301 // ds_ratio) / (5 - -1) + 50 // ds_ratio) + + if y is not None: + y_pixel = int(y * (101 // ds_ratio) / (2 - 0)) + + if z is not None: + if flipped: + z_pixel = int(z * 151 / (1.5 - -1.5)) + else: + z_pixel = int(151 / 2 + z * 151 / (1.5 - -1.5)) + + return x_pixel, y_pixel, z_pixel + + +def convert_pixel2unit( + x_pixel: int = None, + y_pixel: int = None, + z_pixel: int = None, + ds_ratio: float = None, + flipped: bool = True, +) -> tuple: + """ + Convert pixel coordinates to unit coordinates. + + Parameters: + x_pixel (int): X-coordinate in pixels. + y_pixel (int): Y-coordinate in pixels. + z_pixel (int): Z-coordinate in pixels. + ds_ratio (float): Downsampling ratio. + flipped (bool): Whether the z-coordinate is flipped or not. + + Returns: + tuple: (x, y, z) coordinates in original units. 
+ """ + if ds_ratio is None: + raise ValueError("ds_ratio must be provided") + + x, y, z = None, None, None + + if x_pixel is not None: + # Reverse the transformation for x + x = (x_pixel - 50 // ds_ratio) * (5 - (-1)) / (301 // ds_ratio) + + if y_pixel is not None: + # Reverse the transformation for y + y = y_pixel * (2 - 0) / (101 // ds_ratio) + + if z_pixel is not None: + # Reverse the transformation for z + if flipped: + z = z_pixel * (1.5 - (-1.5)) / 151 + else: + z = (z_pixel - 151 / 2) * (1.5 - (-1.5)) / 151 + + return x, y, z + + +def combine_fields( + u: Optional[np.ndarray] = None, + v: Optional[np.ndarray] = None, + w: Optional[np.ndarray] = None, +) -> np.ndarray: + """ + Combine the fields u, v, and optionally w into a single numpy array. + + Parameters: + u (np.ndarray): The u field array with shape [26000, 144, 48]. + v (np.ndarray): The v field array with shape [26000, 144, 48]. + w (np.ndarray, optional): The w field array with shape [26000, 144, 48]. Defaults to None. + + Returns: + np.ndarray: Combined array with shape [26000, 2, 144, 48] or [26000, 3, 144, 48]. + """ + + if u is None or v is None: + raise ValueError("Atleast two fields must be provided") + + # Add a new axis to both u and v + u_expanded = np.expand_dims(u, axis=1) + v_expanded = np.expand_dims(v, axis=1) + + if w is not None: + w_expanded = np.expand_dims(w, axis=1) + # Concatenate u, v, and w along the new axis + data = np.concatenate((u_expanded, v_expanded, w_expanded), axis=1) + else: + # Concatenate only u and v along the new axis + data = np.concatenate((u_expanded, v_expanded), axis=1) + + # print(f"Shape of the combined data: {data.shape}") # Should print either (26000, 2, 144, 48) or (26000, 3, 144, 48) + + return data + + +def extract_data_2d( + U: np.ndarray, + x_pixel: int, + y_pixel: Optional[int], + z_pixel: Optional[int], + data: str, +) -> np.ndarray: + """ + Extract data from a 2D input array based on the specified data type. + + Args: + U (np.ndarray): Input 2D array. + x_pixel (int): X coordinate in pixels. + y_pixel (Optional[int]): Y coordinate in pixels (can be None). + z_pixel (Optional[int]): Z coordinate in pixels (can be None). + data (str): Type of data to extract ('point', 'line-x', 'line-y', 'plane'). + + Returns: + np.ndarray: Extracted data. + """ + if z_pixel is None: # xy plane + if data == "point": + return U[:, x_pixel, y_pixel] + elif data == "line-x": + return U[:, :, y_pixel] + elif data == "line-y": + return U[:, x_pixel, :] + elif data == "plane": + return U[:, :, :] + elif y_pixel is None: # xz plane + if data == "point": + return U[:, x_pixel, z_pixel] + elif data == "line-x": + return U[:, :, z_pixel] + elif data == "line-z": + return U[:, x_pixel, :] + elif data == "plane": + return U[:, :, :] + + +def extract_data_3d( + U: np.ndarray, x_pixel: int, y_pixel: int, z_pixel: int, data: str +) -> np.ndarray: + """ + Extract data from a 3D input array based on the specified data type. + + Args: + U (np.ndarray): Input 3D array. + x_pixel (int): X coordinate in pixels. + y_pixel (int): Y coordinate in pixels. + z_pixel (int): Z coordinate in pixels. + data (str): Type of data to extract ('point', 'line-x', 'line-y', 'xy_plane', 'xz_plane'). + + Returns: + np.ndarray: Extracted data. 
+ """ + if data == "point": + return U[:, x_pixel, y_pixel, z_pixel] + elif data == "line-x": + return U[:, :, y_pixel, z_pixel] + elif data == "line-y": + return U[:, x_pixel, :, z_pixel] + elif data == "xy_plane": + return U[:, :, :, z_pixel] + elif data == "xz_plane": + return U[:, :, y_pixel, :] + + +def get_data_for_stats( + U: np.ndarray, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + input_data_type: str = "2D", + data: Optional[str] = None, + mean_over_time: bool = False, + ds_ratio: Optional[float] = None, +) -> np.ndarray: + """ + Extract specific data from a multidimensional array for statistical analysis. + + Args: + U (np.ndarray): Input array. + x (Optional[float]): X coordinate in units. + y (Optional[float]): Y coordinate in units. + z (Optional[float]): Z coordinate in units. + input_data_type (str): Type of input data ('2D' or '3D'). + data (Optional[str]): Type of data to extract ('point', 'line-x', 'line-y', 'line-z', 'xy_plane', 'xz_plane', 'plane'). + mean_over_time (bool): Whether to average the extracted data over time. + ds_ratio (Optional[float]): Downsampling ratio. + + Returns: + np.ndarray: Extracted (and possibly averaged) data. + """ + # Convert units to pixel values + x_pixel, y_pixel, z_pixel = convert_unit2pixel( + x=x, y=y, z=z, ds_ratio=ds_ratio, flipped=False + ) + + # Extract data based on the input data type + if input_data_type == "2D": + extracted_data = extract_data_2d(U, x_pixel, y_pixel, z_pixel, data) + elif input_data_type == "3D": + extracted_data = extract_data_3d(U, x_pixel, y_pixel, z_pixel, data) + else: + raise ValueError(f"Invalid input_data_type: {input_data_type}") + + # Optionally average the extracted data over time + if mean_over_time: + extracted_data = np.mean(extracted_data, axis=0) + + return extracted_data + + +def select_random(arr, num_elements=1, seed=None): + """ + Randomly select one or more elements from a NumPy array. + + Args: + arr (numpy.ndarray): Input array. + num_elements (int): Number of elements to select randomly. Default is 1. + + Returns: + numpy.ndarray or list: Randomly selected element or list of elements. + """ + if num_elements < 1: + raise ValueError("Number of elements to select must be at least 1.") + + np.random.seed(seed) + + # Generate random indices within the range of the array + random_indices = np.random.choice(arr.shape[0], size=num_elements, replace=False) + + # Return the selected elements + if num_elements == 1: + return arr[random_indices[0]] + else: + return arr[random_indices] + + +def dict2namespace(config): + """ + Convert a nested dictionary to an argparse.Namespace object. + + This function recursively converts dictionary configurations into a namespace + object, which allows accessing dictionary keys as object attributes using + dot notation (e.g., namespace.key instead of dict['key']). + + Parameters + ---------- + config : dict + A dictionary configuration to be converted. Can contain nested dictionaries + which will be recursively converted to nested Namespace objects. + + Returns + ------- + argparse.Namespace + A namespace object with dictionary keys as attributes. Nested dictionaries + are recursively converted to nested Namespace objects. 
+ + Examples + -------- + >>> config = {'a': 1, 'b': {'c': 2, 'd': 3}} + >>> ns = dict2namespace(config) + >>> ns.a + 1 + >>> ns.b.c + 2 + """ + namespace = argparse.Namespace() + for key, value in config.items(): + if isinstance(value, dict): + new_value = dict2namespace(value) + else: + new_value = value + setattr(namespace, key, new_value) + return namespace + + +##Common plot Utils +plot_config = dict2namespace(plot_dict) + + +def add_obstacle_patch(ax, color="k"): + """ + Add a rectangular obstacle patch to a matplotlib axes object. + + This function draws a rectangular patch representing an obstacle in a flow field + visualization. The obstacle position and dimensions are retrieved from the global + plot configuration. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The matplotlib axes object on which to add the obstacle patch. + color : str, optional + The color of the obstacle patch. Can be a color name (e.g., 'k' for black, + 'r' for red) or a hex color code. Default is 'k' (black). + + Returns + ------- + None + + Notes + ----- + The obstacle dimensions and position are obtained from the global `plot_config` + object, specifically from: + - plot_config.figure.obs_pos_x: x-coordinate of the obstacle position + - plot_config.figure.obs_pos_y: y-coordinate of the obstacle position + - plot_config.figure.obs_width: width of the obstacle + - plot_config.figure.obs_height: height of the obstacle + + The obstacle is drawn as a filled rectangle with a visible edge. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> fig, ax = plt.subplots() + >>> add_obstacle_patch(ax, color='red') + >>> ax.set_xlim(-1, 5) + >>> ax.set_ylim(0, 2) + >>> plt.show() + """ + # Obstacle dimensions & location (For one obstacle dataset) + pos_x, pos_y = ( + plot_config.figure.obs_pos_x, + plot_config.figure.obs_pos_y, + ) # x position, y position + width, height = ( + plot_config.figure.obs_width, + plot_config.figure.obs_height, + ) # width, height of the obstacle + obstacle = patches.Rectangle( + (pos_x, pos_y), width, height, linewidth=2, edgecolor=color, facecolor=color + ) + ax.add_patch(obstacle) + + +def plot_subplot( + ax=None, + data=None, + title=None, + extent=None, + vmin=None, + vmax=None, + colormap=plot_config.plot.snap_cmap, + fontsize_title=plot_config.axes.fontsize, + fontsize=plot_config.axes.fontsize, + x_label=plot_config.axes.x_label, + y_label=plot_config.axes.y_label, + x_ticks=plot_config.axes.x_ticks, + y_ticks=plot_config.axes.y_ticks, + ticksize=plot_config.axes.ticksize, + cbar_label=None, + cbar_orientation="vertical", + add_patch=True, + errors=False, +): + """ + Plot a subplot with the given data on the provided axes. + #TODO: Remove the fontsize args, as now we are using the global rc.params + Parameters: + ax (matplotlib.axes.Axes): The axes on which to plot. + data (numpy.ndarray): The data to be plotted. + title (str): The title of the subplot. + errors (bool): Whether to use error colormap limits. Default is False. + add_patch (bool): Whether to add an obstacle patch to the plot. Default is True. + + Returns: + im (matplotlib.image.AxesImage): The image object created by imshow. 
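+
+    Example:
+        A minimal standalone sketch with a synthetic field (relies on the
+        default plot config; the obstacle patch is disabled):
+
+        >>> import numpy as np
+        >>> import matplotlib.pyplot as plt
+        >>> fig, ax = plt.subplots()
+        >>> field = np.random.rand(288, 96)
+        >>> im = plot_subplot(ax=ax, data=field, title="u'",
+        ...                   extent=[-1.0, 4.74, 0.0, 1.9], add_patch=False)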
+ """ + + if errors: + vmin, vmax = 0, 50 + else: + vmin, vmax = vmin, vmax + + im = ax.imshow( + data.T, + cmap=colormap, + extent=extent, + origin="lower", + aspect="auto", + vmin=vmin, + vmax=vmax, + ) + + ax.set_title(title) # , fontsize=fontsize_title) + + ax.set_xlabel(x_label) # , fontsize=fontsize) + ax.set_ylabel(y_label) # , fontsize=fontsize) + + ax.set_xticks(x_ticks) + ax.set_yticks(y_ticks) + + ax.tick_params(axis="both") # , labelsize=ticksize) + + fig = ax.get_figure() + cbar = fig.colorbar(im, ax=ax, orientation=cbar_orientation) + # cbar.ax.tick_params(labelsize=ticksize) + + if cbar_label is not None: + cbar.set_label(cbar_label) # , fontsize=fontsize) + + if add_patch: + add_obstacle_patch(ax) + + if plot_config.figure.tight_layout: + plt.tight_layout() + + return im + + +def calculate_mse(test_data, pred_data, max_val): + """Calculate the Mean Squared Error (MSE) between test and predicted data.""" + error_wake = ((test_data - pred_data) ** 2 / max_val**2) * 100 + return np.mean(error_wake, axis=0), error_wake + + +def error_in_wake(gtruth_data=None, pred_data=None, mask=None, ds_ratio=None): + """ + Calculate the error and mean squared error (MSE) in the wake region of an obstacle in a flow field. + + The test and predicted data are compared to compute the error and MSE in the specified region. The wake region is masked + based on the percentage of the obstacle's height. + + Parameters: + ----------- + gtruth_data : np.ndarray Ground truth data representing the flow field. + pred_data : np.ndarray Predicted data from a model to be compared against the ground truth. + mask : float Percentage (0 to 100) of the obstacle's height that defines the wake region to be analyzed. + + Returns: + -------- + gtruth_data_wake : np.ndarray Ground truth data for the wake region. + pred_data_wake : np.ndarray Predicted data for the wake region. + error_wake : np.ndarray Element-wise squared error between test and predicted data in the wake region. + mse_wake : np.ndarray Mean squared error for the wake region. + mean_mse : np.ndarray Mean of the MSE over spatial dimensions (1 and 2). + """ + + # Obstacle dimensions (in unit coordinates) + pos_x, pos_y = -0.125, 0 + width, height = 0.25, 1 + + # Convert dimensions to pixel values + pos_x_pixel, pos_y_pixel, _ = convert_unit2pixel( + x=pos_x, y=pos_y, ds_ratio=ds_ratio + ) + width_pixel, height_pixel, _ = convert_unit2pixel( + x=width, y=height, ds_ratio=ds_ratio + ) + + # Masked wake region (based on the height mask percentage) + masked_height_pixel = int((mask / 100) * height_pixel) + + # Slice the wake region from the data + gtruth_data_wake = gtruth_data[:, :, width_pixel:, masked_height_pixel:height_pixel] + pred_data_wake = pred_data[:, :, width_pixel:, masked_height_pixel:height_pixel] + + # Calculate MSE and error for wake region + max_gtruth_data = np.max(gtruth_data) + mse_wake, error_wake = calculate_mse( + gtruth_data_wake, pred_data_wake, max_gtruth_data + ) + + # Compute the mean over spatial dimensions + mean_mse = np.mean(mse_wake, axis=(1, 2)) + + return gtruth_data_wake, pred_data_wake, error_wake, mse_wake, mean_mse + + +def plot_inst_comp_and_error( + random_indices=None, + test_data=None, + pred_data=None, + error=None, + mask=None, + comp="streamwise", + x_axis=None, + y_axis=None, + colormap="viridis", + pdf=None, + wake_region=False, + ds_ratio=None, + vmin=None, + vmax=None, +): + """ + Plot instantaneous comparison of test data, predicted data, and the error for selected indices in the specified component. 
+ + Parameters: + ----------- + random_indices : list of int List of indices for which the comparison plots are generated. + test_data : np.ndarray Ground truth data (test data) for comparison. + pred_data : np.ndarray Model predicted data to compare with the test data. + error : np.ndarray Error data (e.g., MSE) between test and predicted data. + comp : str, optional Component to plot. Must be either 'streamwise' (default) or 'wall-normal'. + x_axis : np.ndarray Array representing the x-axis coordinates. + y_axis : np.ndarray Array representing the y-axis coordinates. + colormap : str, optional Colormap to use for the plots. Default is 'viridis'. + pdf : PdfPages object, optional If provided, the plots will be saved to this PDF. Otherwise, plots will be shown. + wake_region : bool, optional Whether to you want to compute and plot wake region specifically for same indices + + Returns: + -------- + pdf : PdfPages object or None If a PDF object is provided, the function returns the modified PDF object. Otherwise, it shows the plots and returns None. + """ + + # Check if comp is valid + if comp not in ["streamwise", "wall-normal"]: + raise ValueError("comp must be either 'streamwise' or 'wall-normal'.") + + # Select the correct channel + channel = 0 if comp == "streamwise" else 1 + + # Number of test data points + num_indices = len(random_indices) + + # Create subplots + fig, axs = plt.subplots( + num_indices, + 3, + figsize=( + 3 * plot_config.figure.figsize[0], + num_indices * plot_config.figure.figsize[1], + ), + ) + + extent = [x_axis.min(), x_axis.max(), y_axis.min(), y_axis.max()] + vmin, vmax = np.min(test_data), np.max(test_data) + + fontsize_title, fontsize = plot_config.axes.fontsize, plot_config.axes.fontsize + + if wake_region: + test_data_wake, pred_data_wake, error_wake, _, _ = error_in_wake( + gtruth_data=test_data, pred_data=pred_data, mask=mask, ds_ratio=ds_ratio + ) + + for i, num in enumerate(random_indices): + plot_subplot( + axs[i, 0], + test_data_wake[num, channel, :, :], + "", + extent, + vmin, + vmax, + colormap, + fontsize_title=fontsize_title, + fontsize=fontsize, + add_patch=False, + ) + plot_subplot( + axs[i, 1], + pred_data_wake[num, channel, :, :], + "", + extent, + vmin, + vmax, + colormap, + fontsize_title=fontsize_title, + fontsize=fontsize, + add_patch=False, + ) + plot_subplot( + axs[i, 2], + error_wake[num, channel, :, :], + "", + extent, + 0, + 50, + colormap, + fontsize_title=fontsize_title, + fontsize=fontsize, + add_patch=False, + ) + + # Loop through the data + for i, num in enumerate(random_indices): + plot_subplot( + axs[i, 0], + test_data[num, channel, :, :], + "", + extent, + vmin, + vmax, + colormap, + fontsize_title=fontsize_title, + fontsize=fontsize, + add_patch=True, + ) + plot_subplot( + axs[i, 1], + pred_data[num, channel, :, :], + "", + extent, + vmin, + vmax, + colormap, + fontsize_title=fontsize_title, + fontsize=fontsize, + add_patch=True, + ) + plot_subplot( + axs[i, 2], + error[num, channel, :, :], + "", + extent, + 0, + 40, + colormap, + fontsize_title=fontsize_title, + fontsize=fontsize, + add_patch=True, + ) + + # Save to PDF if provided + if pdf is not None: + pdf.savefig(fig) + return pdf + else: + # plt.show() + pass + # Close the figure + plt.close(fig) + + return None diff --git a/examples/weather/diffusion-urban-flows-2D/helpers/stats_helpers.py b/examples/weather/diffusion-urban-flows-2D/helpers/stats_helpers.py new file mode 100644 index 0000000000..b3ebe26e66 --- /dev/null +++ 
b/examples/weather/diffusion-urban-flows-2D/helpers/stats_helpers.py @@ -0,0 +1,1248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Statistical evaluation and visualization utilities for flow field predictions. + +This module provides comprehensive statistical evaluation tools for comparing ground truth +and predicted flow field data. It includes functionality for: +- Computing Reynolds stress components (normal and shear stresses) +- Generating joint probability density functions (PDFs) +- Computing power spectral density (PSD) using Welch's method +- Creating visual comparisons of predicted vs ground truth fields +- Plotting probe signals over time +- Supporting both 2D and 3D flow field data + +The main class `StatisticalEvaluation` provides a comprehensive framework for evaluating +diffusion model predictions from a statistical perspective. +""" + +import numpy as np +import matplotlib.pyplot as plt + +import time +import os + +from helpers.general_helpers import ( + get_data_for_stats, + add_obstacle_patch, + plot_subplot, + dict2namespace, +) +from conf.plot_configs import plot_dict + + +# basic_plt_setup() +plot_config = dict2namespace(plot_dict) + + +class StatisticalEvaluation: + """ + Comprehensive statistical evaluation class for comparing ground truth and predicted flow fields. + + This class provides an integrated framework for evaluating diffusion model predictions + from a statistical perspective. It generates detailed visualizations and analyses including: + visual field comparisons, Reynolds stress components, joint probability density functions, + and power spectral density plots. Outputs are saved as PDF and PNG files. + + The class supports both 2D and 3D flow field data and handles multiple velocity components + (u, v, and optionally w). It provides methods for computing various statistical metrics + commonly used in fluid dynamics research. + + Attributes + ---------- + gtruth : np.ndarray + Ground truth flow field data with shape (time, components, spatial_dims...). + pred : np.ndarray + Predicted flow field data with the same shape as ground truth. + x_axis : np.ndarray + X-axis coordinates for the domain. + y_axis : np.ndarray + Y-axis coordinates for the domain. + z_axis : np.ndarray, optional + Z-axis coordinates for 3D domains. + time : np.ndarray + Time steps corresponding to the temporal dimension. + input_data_type : str + Type of input data ('2D' or '3D'). + ds_ratio : float + Downsampling ratio applied to the data. + snaps : int + Number of snapshots in the prediction data. + + Notes + ----- + The class uses global plot configuration from conf.plot_configs to maintain + consistent visualization styles across all generated plots. 
+ + Examples + -------- + >>> gtruth = np.random.rand(100, 2, 60, 20) # 100 time steps, 2 components, 60x20 spatial grid + >>> pred = np.random.rand(100, 2, 60, 20) + >>> x_axis = np.linspace(-1, 5, 60) + >>> y_axis = np.linspace(0, 2, 20) + >>> time = np.linspace(0, 10, 100) + >>> evaluator = StatisticalEvaluation(gtruth=gtruth, pred=pred, x_axis=x_axis, + ... y_axis=y_axis, time=time, input_data_type='2D') + >>> # Generate statistical evaluation plots + >>> pdf = evaluator.main(num=0, locations=[1, 2, 3], y=0.5) + """ + + def __init__( + self, + gtruth=None, + pred=None, + x_axis=None, + y_axis=None, + z_axis=None, + time=None, + input_data_type="2D", + data=None, + ds_ratio=1, + output_file_path=None, + ): + """ + Initialize the StatisticalEvaluation class with the provided parameters and evaluate the DDPM from statistical sense. + + Parameters: + gtruth (ndarray): Ground truth data. + pred (ndarray): Predicted data. + x_axis (ndarray): X-axis values. + y_axis (ndarray): Y-axis values. + z_axis (ndarray): Z-axis values. + time (ndarray): Time values. + input_data_type (str): Type of input data ('2D' or '3D'). + data (ndarray): Additional data. + config (object): Configuration object. + """ + + assert gtruth is not None and pred is not None, ( + "Ground truth and prediction must not be None" + ) + assert gtruth.any() and pred.any(), ( + "Ground truth and prediction must not be empty" + ) + + self.gtruth = gtruth + self.pred = pred + self.snaps = pred.shape[0] + + self.x_axis = x_axis + self.y_axis = y_axis + self.z_axis = z_axis + self.time = time + + self.input_data_type = input_data_type + self.data = data + self.plot_config = plot_config + + self.ds_ratio = ds_ratio + + # Obstacle dimensions & location + self.pos_x, self.pos_y = ( + self.plot_config.figure.obs_pos_x, + self.plot_config.figure.obs_pos_y, + ) + self.width, self.height = ( + self.plot_config.figure.obs_width, + self.plot_config.figure.obs_height, + ) # width, height of the obstacle + + self.min, self.max = np.min(self.gtruth), np.max(self.gtruth) + + self.eval_dir = output_file_path + + self.x_label = self.plot_config.axes.x_label + self.y_label = self.plot_config.axes.y_label + self.fontsize = self.plot_config.axes.fontsize + + # plt.rcParams['font.family'] = 'serif' + # plt.rcParams['text.usetex'] = True # Enable LaTeX rendering + + if input_data_type == "2D": + assert self.gtruth.shape[1] == self.pred.shape[1] == 2, ( + "Expected 2 velocity components in 2D data" + ) + assert len(self.gtruth.shape) == len(self.pred.shape) == 4, ( + "Expected 2D data to have 4 dimensions (time, nc ,u, v)" + ) + + elif input_data_type == "3D": + assert self.gtruth.shape[1] == self.pred.shape[1] == 3, ( + "Expected 3 velocity components in 3D data" + ) + assert len(self.gtruth.shape) == len(self.pred.shape) == 5, ( + "Expected 3D data to have 5 dimensions (time, nc ,u, v, w)" + ) + + if not os.path.exists(self.eval_dir): + os.mkdir(self.eval_dir) + + def plot_vis_compare(self, num=0, pdf=None): + """ + Plot visual comparison of ground truth and predicted data. + + Parameters: + num (int): Index for the time step to plot. + pdf (PdfPages): PDF object to save the plots. 
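+
+        Note:
+            One row is drawn per velocity component. Columns show, left to
+            right: the ground truth at index num, the prediction at num, and a
+            second prediction at num + 1, giving a qualitative sense of
+            sample-to-sample variability.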
+ """ + + if self.input_data_type == "2D": + fig, axs = plt.subplots( + 2, + 3, + figsize=( + 3 * self.plot_config.figure.figsize[0], + 2 * self.plot_config.figure.figsize[1], + ), + ) + + elif self.input_data_type == "3D": + fig, axs = plt.subplots( + 2, + 3, + figsize=( + 3 * self.plot_config.figure.figsize[0], + 2 * self.plot_config.figure.figsize[1], + ), + ) + + else: + raise ValueError("Unsupported input data type") + + for i in range(self.gtruth.shape[1]): + # Plot Ground Truth + + extent = [ + self.x_axis.min(), + self.x_axis.max(), + self.y_axis.min(), + self.y_axis.max(), + ] + + plot_subplot( + axs[i, 0], + self.gtruth[num, i, :, :], + extent=extent, + vmin=self.min, + vmax=self.max, + ) + plot_subplot( + axs[i, 1], + self.pred[num, i, :, :], + extent=extent, + vmin=self.min, + vmax=self.max, + ) + plot_subplot( + axs[i, 2], + self.pred[num + 1, i, :, :], + extent=extent, + vmin=self.min, + vmax=self.max, + ) + + if pdf is not None: + pdf.savefig(fig) + plt.savefig( + self.eval_dir + + f"/pred_snaps-{self.snaps}-visual_comparison_num-{num}.pdf", + dpi=self.plot_config.figure.dpi, + ) + else: + plt.savefig( + self.eval_dir + + f"/pred_snaps-{self.snaps}-visual_comparison_num-{num}.png", + dpi=self.plot_config.figure.dpi, + ) + + plt.close(fig) + + return pdf + + def reynolds_stress(self, inp=None, x=None, y=None, z=None, data=None, pdf=None): + """ + Compute Reynolds stresses for the given input data. + + Parameters: + inp (ndarray): Input data. + x (ndarray): X-axis values. + y (ndarray): Y-axis values. + z (ndarray): Z-axis values. + data (str): Data type ('line-x', 'plane', etc.). + pdf (PdfPages): PDF object to save the plots. + + Returns: + tuple: Reynolds stress components. + """ + + if self.input_data_type == "3D": + u, v, w = inp[:, 0], inp[:, 1], inp[:, 2] + assert len(u.shape) == 4, ( + "Expected 3D data to have 4 dimensions (time, x, y, z)" + ) + + elif self.input_data_type == "2D": + u, v, w = inp[:, 0], inp[:, 1], None + assert len(u.shape) == 3, ( + "Expected 2D data to have 3 dimensions (time, x, y)" + ) + + # Extract the relevant data for statistics + u_pt = get_data_for_stats( + u, + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + mean_over_time=False, + ) + v_pt = get_data_for_stats( + v, + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + mean_over_time=False, + ) + + # Compute the Reynolds stresses + rs_uu = np.mean(u_pt * u_pt, axis=0) + rs_vv = np.mean(v_pt * v_pt, axis=0) + rs_uv = np.mean(np.abs(u_pt * v_pt), axis=0) + + if w is None: + rs_ww = None + rs_vw = None + rs_uw = None + else: + w_pt = get_data_for_stats( + w, + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + mean_over_time=False, + ) + rs_ww = np.mean(w_pt * w_pt, axis=0) + rs_vw = np.mean(v_pt * w_pt, axis=0) + rs_uw = np.mean(u_pt * w_pt, axis=0) + + return rs_uu, rs_vv, rs_ww, rs_uv, rs_vw, rs_uw + + def plot_reynolds_stresses(self, x=None, y=None, z=None, data="line-x", pdf=None): + """ + Plot Reynolds stresses for ground truth and predicted data. + + Parameters: + x (ndarray): X-axis values. + y (ndarray): Y-axis values. + z (ndarray): Z-axis values. + data (str): Data type ('line-x', 'plane', etc.). + pdf (PdfPages): PDF object to save the plots. 
+ """ + rs_uu, rs_vv, rs_ww, rs_uv, rs_vw, rs_uw = self.reynolds_stress( + inp=self.gtruth, x=x, y=y, z=z, data=data + ) + rs_uu_pred, rs_vv_pred, rs_ww_pred, rs_uv_pred, rs_vw_pred, rs_uw_pred = ( + self.reynolds_stress(inp=self.pred, x=x, y=y, z=z, data=data) + ) + + # Check if rs_ww and rs_ww_pred are None + is_3d = rs_ww is not None and rs_ww_pred is not None + + # Plotting + fig, ax = plt.subplots( + figsize=( + 2 * self.plot_config.figure.figsize[1], + 2 * self.plot_config.figure.figsize[1], + ) + ) + + ax.plot( + self.x_axis, + rs_uu, + label=self.plot_config.axes.re_norm_stresses[0], + linestyle="-", + color="b", + ) + ax.plot( + self.x_axis, + rs_uu_pred, + label=self.plot_config.axes.re_norm_stresses_p[0], + linestyle="None", + marker="o", + color="b", + ) + + ax.plot( + self.x_axis, + rs_vv, + label=self.plot_config.axes.re_norm_stresses[1], + linestyle="-", + color="g", + ) + ax.plot( + self.x_axis, + rs_vv_pred, + label=self.plot_config.axes.re_norm_stresses_p[1], + linestyle="None", + marker="^", + color="g", + ) + + ax.plot( + self.x_axis, + rs_uv, + label=self.plot_config.axes.re_sh_stresses[0], + linestyle="-", + color="r", + ) + ax.plot( + self.x_axis, + rs_uv_pred, + label=self.plot_config.axes.re_sh_stresses_p[0], + linestyle="None", + marker="v", + color="r", + ) + + if is_3d: + ax.plot( + self.x_axis, + rs_ww, + label=self.plot_config.axes.re_norm_stresses[2], + linestyle="-", + color="c", + ) + ax.plot( + self.x_axis, + rs_ww_pred, + label=self.plot_config.axes.re_norm_stresses_p[2], + linestyle="None", + marker="s", + color="c", + ) + + ax.plot( + self.x_axis, + rs_uw, + label=self.plot_config.axes.re_sh_stresses[1], + linestyle="-", + color="m", + ) + ax.plot( + self.x_axis, + rs_uw_pred, + label=self.plot_config.axes.re_sh_stresses_p[1], + linestyle="None", + marker="d", + color="m", + ) + + ax.plot( + self.x_axis, + rs_vw, + label=self.plot_config.axes.re_sh_stresses[2], + linestyle="-", + color="y", + ) + ax.plot( + self.x_axis, + rs_vw_pred, + label=self.plot_config.axes.re_sh_stresses_p[2], + linestyle="None", + marker="*", + color="y", + ) + + ax.set_xlabel(self.x_label) # , fontsize = self.plot_config.axes.fontsize) + ax.set_ylabel( + r"$\overline{{u_i}^{\prime}{u_j}^{\prime}}$" + ) # , fontsize=self.plot_config.axes.fontsize) + ax.legend() # prop={'size': self.plot_config.axes.fontsize}) + ax.set_title( + "Reynolds Stress Components" + ) # , fontsize = self.plot_config.axes.fontsize) + + if self.plot_config.figure.tight_layout: + fig.tight_layout() + + if pdf is not None: + pdf.savefig(fig) + plt.savefig( + self.eval_dir + f"/pred_snaps-{self.snaps}-Reynolds_stresses1.pdf", + dpi=self.plot_config.figure.dpi, + ) + else: + plt.savefig( + self.eval_dir + f"/pred_snaps-{self.snaps}-Reynolds_stresses1.pdf", + dpi=self.plot_config.figure.dpi, + ) + plt.close(fig) + + return pdf + + def _plot_contour( + self, + ax, + Y_grid, + X_grid, + data, + data_pred, + title, + levels, + cmap, + fontsize, + cbar_orientation="vertical", + ): + """ + Helper method to plot a single contour plot with ground truth and predicted data. + #TODO: Move this to utils + Parameters: + ax (Axes): The matplotlib Axes object to plot on. + Y_grid (ndarray): Y-axis grid values. + X_grid (ndarray): X-axis grid values. + data (ndarray): Ground truth data. + data_pred (ndarray): Predicted data. + title (str): Title of the plot. + levels (int): Number of contour levels. + cmap (str): Colormap. + fontsize (int): Font size for the title. 
+ cbar_orientation (str): Orientation of the colorbar ('vertical' or 'horizontal'). Default is 'vertical'. + """ + + # Create filled contour plot for ground truth + contour = ax.contourf(Y_grid, X_grid, data, levels=levels, cmap=cmap) + + # Create contour plot for predicted data in black + ax.contour(Y_grid, X_grid, data_pred, levels=levels, colors="black") + + # Add obstacle patch to the plot (custom method) + add_obstacle_patch(ax) + + # Set title and axis labels + ax.set_title(title, fontsize=fontsize) + ax.set_xlabel( + self.plot_config.axes.x_label + ) # , fontsize=self.plot_config.axes.fontsize) + ax.set_ylabel( + self.plot_config.axes.y_label + ) # , fontsize=self.plot_config.axes.fontsize) + + # Set ticks and their sizes + ax.set_xticks(self.plot_config.axes.x_ticks) + ax.set_yticks(self.plot_config.axes.y_ticks) + # ax.tick_params(axis='both', labelsize=self.plot_config.axes.ticksize) + + # Get the figure object from the axes + fig = ax.get_figure() + + # Add a colorbar with the correct orientation and tick size + fig.colorbar(contour, ax=ax, orientation=cbar_orientation) + # cbar.ax.tick_params(labelsize=self.plot_config.axes.ticksize) + + def plot_reynolds_stress_planes( + self, x=None, y=None, z=None, data="plane", pdf=None + ): + """ + Plot Reynolds stress planes for ground truth and predicted data. + + Parameters: + x (ndarray): X-axis values. + y (ndarray): Y-axis values. + z (ndarray): Z-axis values. + data (str): Data type ('plane', etc.). + levels (int): Number of contour levels. + pdf (PdfPages): PDF object to save the plots. + + Returns: + pdf (PdfPages): PDF object with the saved plots. + """ + + rs_uu, rs_vv, rs_ww, rs_uv, rs_vw, rs_uw = self.reynolds_stress( + self.gtruth, x=x, y=y, z=z, data=data + ) + rs_uu_pred, rs_vv_pred, rs_ww_pred, rs_uv_pred, rs_vw_pred, rs_uw_pred = ( + self.reynolds_stress(self.pred, x=x, y=y, z=z, data=data) + ) + + # Check if rs_ww and rs_ww_pred are None + is_3d = rs_ww is not None and rs_ww_pred is not None + + # Create mesh grid for plotting + X_grid, Y_grid = np.meshgrid(self.y_axis, self.x_axis) + + cmap = self.plot_config.plot.re_stress_cmap + fontsize = self.plot_config.axes.fontsize + levels = self.plot_config.figure.re_levels + + if not is_3d: + fig, axs = plt.subplots( + 1, + 3, + figsize=( + 2.5 * self.plot_config.figure.figsize[0], + 1 * self.plot_config.figure.figsize[1], + ), + ) + + # Plot uu + self._plot_contour( + axs[0], + Y_grid, + X_grid, + rs_uu, + rs_uu_pred, + self.plot_config.axes.re_norm_stresses[0], + levels, + cmap, + fontsize, + ) + + # Plot vv + self._plot_contour( + axs[1], + Y_grid, + X_grid, + rs_vv, + rs_vv_pred, + self.plot_config.axes.re_norm_stresses[1], + levels, + cmap, + fontsize, + ) + + # Plot uv + self._plot_contour( + axs[2], + Y_grid, + X_grid, + rs_uv, + rs_uv_pred, + self.plot_config.axes.re_sh_stresses[0], + levels, + cmap, + fontsize, + ) + + else: + fig, axs = plt.subplots( + 3, + 2, + figsize=( + 5 * self.plot_config.figure.figsize[0], + 2 * self.plot_config.figure.figsize[1], + ), + ) + # Plot uu + self._plot_contour( + axs[0, 0], + Y_grid, + X_grid, + rs_uu, + rs_uu_pred, + self.plot_config.axes.re_norm_stresses[0], + levels, + cmap, + fontsize, + ) + + # Plot vv + self._plot_contour( + axs[0, 1], + Y_grid, + X_grid, + rs_vv, + rs_vv_pred, + self.plot_config.axes.re_norm_stresses[1], + levels, + cmap, + fontsize, + ) + + # Plot ww + self._plot_contour( + axs[1, 0], + Y_grid, + X_grid, + rs_ww, + rs_ww_pred, + self.plot_config.axes.re_norm_stresses[2], + levels, + cmap, + fontsize, + ) + 
+ # Plot uv + self._plot_contour( + axs[1, 1], + Y_grid, + X_grid, + rs_uv, + rs_uv_pred, + self.plot_config.axes.re_sh_stresses[0], + levels, + cmap, + fontsize, + ) + + # Plot uw + self._plot_contour( + axs[2, 0], + Y_grid, + X_grid, + rs_uw, + rs_uw_pred, + self.plot_config.axes.re_sh_stresses[1], + levels, + cmap, + fontsize, + ) + + # Plot vw + self._plot_contour( + axs[2, 1], + Y_grid, + X_grid, + rs_vw, + rs_vw_pred, + self.plot_config.axes.re_sh_stresses[2], + levels, + cmap, + fontsize, + ) + + if self.plot_config.figure.tight_layout: + fig.tight_layout() + + if pdf is not None: + pdf.savefig(fig) + plt.savefig( + self.eval_dir + f"/pred_snaps-{self.snaps}-Reynolds_stresses2.pdf", + dpi=self.plot_config.figure.dpi, + ) + else: + plt.savefig( + self.eval_dir + f"/pred_snaps-{self.snaps}-Reynolds_stresses2.pdf", + dpi=self.plot_config.figure.dpi, + ) + + plt.close(fig) + + return pdf + + def joint_pdfs(self, inp_comp=None, axis=None): + """ + Calculate the joint probability density function (PDF) of input components and their corresponding axis values. + + Parameters: + inp_comp (ndarray): The input components to calculate the PDF for. + axis (ndarray): The axis values corresponding to the input components. + + Returns: + xi (ndarray): The meshgrid x-values for contour plotting. + yi (ndarray): The meshgrid y-values for contour plotting. + zi_norm (ndarray): The normalized joint PDF values for the meshgrid. + """ + from scipy.stats import gaussian_kde + + u_copy = inp_comp.reshape(-1) + x_copy = ( + np.tile(axis.reshape(1, inp_comp.shape[1]), (inp_comp.shape[0], 1)) + ).reshape(-1) + + # Calculate the point density + xy = np.vstack([x_copy, u_copy]) + gaussian_kde(xy) + + # Create a grid for contour plotting + # print(u_.min()) + # xi, yi = np.linspace(axis.min(), axis.max(), 100), np.linspace(inp_comp.min(), inp_comp.max(), 100) + xi, yi = np.linspace(-1, 5, 100), np.linspace(-0.4, 0.4, 100) + xi, yi = np.meshgrid(xi, yi) + zi = gaussian_kde(xy)(np.vstack([xi.flatten(), yi.flatten()])).reshape(xi.shape) + + zi_norm = zi / np.max(zi) + + return xi, yi, zi_norm + + def plot_joint_pdfs(self, x=None, y=None, z=None, data="line-x", pdf=None): + """ + Plot the joint probability density functions (PDFs) for the components of the ground truth and predicted data. + + Parameters: + x (ndarray): The x-coordinates for data selection. + y (ndarray): The y-coordinates for data selection. + z (ndarray): The z-coordinates for data selection. + data (str): The type of data to process ('line-x', 'line-y', etc.). + levels (int): The number of contour levels to plot. + pdf (PdfPages): An optional PdfPages object to save the plots to a PDF file. + + Returns: + pdf (PdfPages): The PdfPages object if provided, with the plots saved. 
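+
+        Note:
+            Densities are estimated with scipy.stats.gaussian_kde on a fixed
+            100 x 100 grid spanning x in [-1, 5] and fluctuations in
+            [-0.4, 0.4], and are normalized by their maximum before contouring.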
+ """ + levels = self.plot_config.figure.jpdf_level + labels = self.plot_config.axes.fluc_label + fontsize = self.plot_config.axes.fontsize + + for i in range(self.gtruth.shape[1]): + component = get_data_for_stats( + self.gtruth[:, i], + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + ) + component_pred = get_data_for_stats( + self.pred[:, i], + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + ) + + # print("creating pdfs") + xi, yi, zi_norm = self.joint_pdfs( + component, axis=self.x_axis + ) # print("gtruth done") + xi, yi, zi_norm_pred = self.joint_pdfs( + component_pred, axis=self.x_axis + ) # print("pred done") + + # Plotting + fig, ax = plt.subplots( + figsize=( + self.plot_config.figure.figsize[0], + self.plot_config.figure.figsize[0], + ) + ) + + contour = ax.contour( + xi, yi, zi_norm, levels=levels, cmap=self.plot_config.plot.jpdf_cmap + ) + ax.contour( + xi, yi, zi_norm_pred, levels=levels, colors="black", linestyles="dotted" + ) + + # plt.title("Joint PDF of $x$ and $u'$", fontsize=self.fontsize) + ax.set_xlabel(self.plot_config.axes.x_label, fontsize=fontsize) + ax.set_ylabel(labels[i], fontsize=fontsize) + fig.colorbar(contour, ax=ax) + # ax.ylim(-0.4, 0.4) + + fig.tight_layout() + + if pdf is not None: + pdf.savefig(fig) + plt.savefig( + self.eval_dir + f"/pred_snaps-{self.snaps}-jpdfs-{i}.pdf", + dpi=self.plot_config.figure.dpi, + ) + else: + plt.savefig( + self.eval_dir + f"/pred_snaps-{self.snaps}-jpdfs-{i}.pdf", + dpi=self.plot_config.figure.dpi, + ) + + plt.close(fig) + + return pdf + + def PSD_welch(self, inp_comp=None, nperseg=256): + """ + Compute the Power Spectral Density (PSD) of an input component using the Welch method. + + Parameters: + inp_comp (ndarray): The input component for which the PSD is to be computed. + nperseg (int): Length of each segment for the Welch method (default is 256). + + Returns: + frequencies (ndarray): Array of sample frequencies. + psd (ndarray): Power spectral density of the input component. + """ + from scipy.signal import welch + + inp_comp = inp_comp.flatten() + + fs = self.time[1] - self.time[0] + + frequencies, psd = welch(inp_comp, fs=fs, nperseg=nperseg) + + return frequencies, psd + + def plot_PSD( + self, locations=None, y=None, z=None, data="point", nperseg=256, pdf=None + ): + """ + Plot the Power Spectral Density (PSD) for ground truth and predicted data using the Welch method. + + Parameters: + locations (list): List of x locations for point data. + y (float): y-coordinate for line-x data. + z (float): z-coordinate for line-x data. + data (str): Type of data, either 'point' or 'line-x'. + nperseg (int): Length of each segment for the Welch method (default is 256). + pdf (PdfPages): PdfPages object to save plots to a PDF file. + + Returns: + pdf (PdfPages): PdfPages object with saved figures. 
+ """ + colors = ["#1f77b4", "#2ca02c", "#d62728", "#ff7f0e", "#e377c2", "#17becf"] + # Blue, Green, Red, Orange, Magenta, Cyan + labels = ["u", "v", "w"] + # Create the fig and axis + linewidth = 2 + + if data == "point": + assert self.gtruth.shape[1] == self.pred.shape[1] + for k in range(self.gtruth.shape[1]): + fig, ax = plt.subplots(figsize=(12, 8)) + + for i, x in enumerate(locations): + # Get ground truth and predicted data + comp_x = get_data_for_stats( + self.gtruth[:, k], + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + ) + comp_x_pred = get_data_for_stats( + self.pred[:, k], + x=x, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + ) + + # Compute PSD using Welch's method + frequency, psd = self.PSD_welch(comp_x, nperseg=nperseg) + frequency, psd_p = self.PSD_welch(comp_x_pred, nperseg=nperseg) + + # Plot the PSD + ax.loglog( + frequency, + psd, + label=rf"GT @ $\frac{{x}}{{h}}=$ {x}", + linestyle="-", + linewidth=linewidth, + color=colors[i], + ) + ax.loglog( + frequency, + psd_p, + label=rf"pred @ $\frac{{x}}{{h}}=$ {x}", + color=colors[i], + marker="o", + markersize=5, + linestyle="none", + ) + + # Add titles and labels + ax.set_title( + "Power Spectral Density vs Frequency (Welch Method)", + fontsize=self.fontsize, + ) + ax.set_xlabel("Frequency [Hz]", fontsize=self.fontsize) + ax.set_ylabel("PSD [V**2/Hz]", fontsize=self.fontsize) + + # Customize grid and legend + ax.grid(True, which="both", linestyle="--", linewidth=0.5) + ax.legend(fontsize=12, loc="upper right", framealpha=0.9) + + # Tight layout for better spacing + fig.tight_layout() + if pdf is not None: + pdf.savefig(fig) + else: + plt.savefig( + self.eval_dir + + f"/pred_snaps-{self.snaps}-PSD-pt-{labels[k]}.png", + dpi=self.plot_config.figure.dpi, + ) + plt.close(fig) + + elif data == "line-x": + for k in range(self.gtruth.shape[1]): + fig, ax = plt.subplots(figsize=(12, 8)) + + # Get ground truth and predicted data + comp_x = get_data_for_stats( + self.gtruth[:, k], + x=None, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + ) + comp_x_pred = get_data_for_stats( + self.pred[:, k], + x=None, + y=y, + z=z, + input_data_type=self.input_data_type, + data=data, + ds_ratio=self.ds_ratio, + ) + + # Compute PSD using Welch's method + frequency, psd = self.PSD_welch(comp_x, nperseg=nperseg) + frequency, psd_p = self.PSD_welch(comp_x_pred, nperseg=nperseg) + + # Plot the PSD + ax.loglog( + frequency, + psd, + label=rf"GT @ $\frac{{y}}{{h}}=$ {y}", + linestyle="-", + linewidth=linewidth, + color="k", + ) + ax.loglog( + frequency, + psd_p, + label=rf"pred @ $\frac{{y}}{{h}}=$ {y}", + color="k", + marker="o", + markersize=5, + linestyle="none", + ) + + # Add titles and labels + ax.set_title( + "Power Spectral Density vs Frequency (Welch Method)", fontsize=16 + ) + ax.set_xlabel("Frequency [Hz]", fontsize=14) + ax.set_ylabel("PSD [V**2/Hz]", fontsize=14) + + # Customize grid and legend + ax.grid(True, which="both", linestyle="--", linewidth=0.5) + ax.legend(fontsize=12, loc="upper right", framealpha=0.9) + + # Tight layout for better spacing + fig.tight_layout() + + if pdf is not None: + pdf.savefig(fig) + else: + plt.savefig( + self.eval_dir + + f"/pred_snaps-{self.snaps}-PSD_linex-{labels[k]}.png", + dpi=self.plot_config.figure.dpi, + ) + plt.close(fig) + + return pdf + + def plot_probe_signal(self, n=200, x=None, y=None, z=None): + """ + Plot time series of velocity components at a 
+
+        This method creates a visualization of velocity fluctuations over time at a
+        specified spatial location. It compares ground truth and predicted velocity
+        signals side-by-side for validation. Ground-truth snapshots are randomly
+        subsampled, while the first `n` predicted snapshots are shown; the two
+        series are not time-aligned, since the model generates unconditional
+        samples.
+
+        Parameters
+        ----------
+        n : int, optional
+            Number of time steps to display in the plot. Default is 200.
+        x : float, optional
+            X-coordinate of the probe location in physical units. If None, data
+            extraction behavior depends on input_data_type.
+        y : float, optional
+            Y-coordinate of the probe location in physical units.
+        z : float, optional
+            Z-coordinate of the probe location in physical units (for 3D data).
+
+        Returns
+        -------
+        None
+
+        Notes
+        -----
+        The method plots one figure per velocity component: three for 3D data
+        (u, v, w), two for 2D data (u, v). Each figure shows:
+        - Ground truth velocity fluctuations (black line)
+        - Predicted velocity fluctuations (red line)
+        - Grid lines for easier reading
+        - Legend with location information
+
+        The plot includes:
+        - X-axis: Number of snapshots
+        - Y-axis: Velocity fluctuation magnitude
+        - Y-axis limits: [-0.5, 0.5] for standard normalization
+
+        Uses the internal ds_ratio attribute for coordinate conversion.
+
+        Examples
+        --------
+        >>> evaluator = StatisticalEvaluation(gtruth=gtruth, pred=pred, ...)
+        >>> evaluator.plot_probe_signal(n=200, x=1.0, y=0.5, z=None)
+        """
+
+        labels = ["u", "v", "w"]
+        for i in range(self.gtruth.shape[1]):
+            u_pt = get_data_for_stats(
+                self.gtruth[:, i],
+                x=x,
+                y=y,
+                z=z,
+                input_data_type=self.input_data_type,
+                data="point",
+                mean_over_time=False,
+                ds_ratio=self.ds_ratio,
+            )
+            u_pt_p = get_data_for_stats(
+                self.pred[:, i],
+                x=x,
+                y=y,
+                z=z,
+                input_data_type=self.input_data_type,
+                data="point",
+                mean_over_time=False,
+                ds_ratio=self.ds_ratio,
+            )
+
+            # Randomly subsample n ground-truth snapshots; the predictions are
+            # not time-aligned with the ground truth, so the first n are shown
+            random_integers = np.random.randint(0, u_pt.shape[0], n)
+
+            # Plot settings
+            plt.figure(figsize=(12, 8))
+
+            plt.plot(
+                u_pt[random_integers],
+                label=rf"${labels[i]}'$ @ $\frac{{x}}{{h}}=$ {x}",
+                linestyle="-",
+                color="k",
+            )
+            plt.plot(
+                u_pt_p[:n],
+                label=rf"${labels[i]}_p'$ @ $\frac{{x}}{{h}}=$ {x}",
+                linestyle="-",
+                color="r",
+            )
+
+            # Add grid lines
+            plt.grid(True, which="both", linestyle="--", linewidth=0.5)
+
+            # Add titles and labels
+            plt.xlabel("num of snapshots", fontsize=14)
+            plt.ylabel(rf"${labels[i]}'$", fontsize=14)
+
+            # Add a legend
+            plt.legend(loc="best", fontsize=12)
+
+            # Fix axis limits so both series are easily comparable
+            plt.xlim([0, n])
+            plt.ylim([-0.5, 0.5])
+            plt.tight_layout()
+
+            # Show the plot
+            plt.show()
+
+    def main(self, num=None, locations=None, y=None, pdf=None):
+        """
+        Main function to execute various plotting and analysis routines.
+
+        Parameters:
+            num (int): The index or identifier for the visual comparison (optional).
+            locations (list): List of x locations for PSD plotting (optional).
+            y (float): y-coordinate for line-x data in various plots (optional).
+            pdf (PdfPages): PdfPages object to save plots to a PDF file (optional).
+
+        Returns:
+            pdf (PdfPages): PdfPages object with saved figures.
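+
+        Example:
+            >>> pdf = PdfPages("./stats_eval.pdf")
+            >>> pdf = evaluator.main(num=0, locations=[0.5], y=0.5, pdf=pdf)
+            >>> pdf.close()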
+ """ + print("Getting a visual glimps of prediction") + pdf = self.plot_vis_compare(num=0, pdf=pdf) + + print("Computing Reynolds Stresses") + pdf = self.plot_reynolds_stresses(x=None, y=0.5, z=None, pdf=pdf) + pdf = self.plot_reynolds_stress_planes( + x=None, y=None, z=None, data="plane", pdf=pdf + ) + + print("Getting pdfs for flow fields") + pdf = self.plot_joint_pdfs(x=None, y=0.5, z=None, data="line-x", pdf=pdf) + + print("Computing Power Spectral Density") + pdf = self.plot_PSD( + locations=None, y=0.5, z=None, data="line-x", nperseg=256, pdf=pdf + ) + + return pdf + + +if __name__ == "__main__": + import sys + + sys.path.append("../") + + from configs.OneObs2D_ds1_10M import config_dict + from libs import runner + from libs.utils import get_data_for_stats + + from matplotlib.backends.backend_pdf import PdfPages + + print( + "Class to evaluate statistics for given ground truth and predictions. Outputs a pdf file with relevant statistics" + ) + print("Starting dummy implementation!! ") + # Dummy data and configuration + gtruth = np.random.rand(5, 2, 60, 20) # Example ground truth data + pred = np.random.rand(5, 2, 60, 20) # Example predicted data + x_axis = np.linspace(-1, 5, 60) + y_axis = np.linspace(0, 2, 20) + time = np.linspace(0, 10, 5) # Example time array + + config = runner.dict2namespace(config_dict) + + # Instantiate the class + evaluator = StatisticalEvaluation( + gtruth=gtruth, + pred=pred, + x_axis=x_axis, + y_axis=y_axis, + z_axis=None, + time=time, + input_data_type="2D", + data="line-x", + config=config, + ) + + # Setup PDF output + pdf = PdfPages("./test_stats_eval.pdf") + pdf = evaluator.main( + num=0, locations=[0.5], y=0.5, pdf=pdf + ) # locations = locations along x, where the PSD needs to be calculated + pdf.close() diff --git a/examples/weather/diffusion-urban-flows-2D/helpers/train_helpers.py b/examples/weather/diffusion-urban-flows-2D/helpers/train_helpers.py new file mode 100644 index 0000000000..d7de41348f --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/helpers/train_helpers.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Training helper utilities for diffusion model training. + +This module provides utility functions for training diffusion models, including: +- Configuring patch shapes for input images +- Setting random seeds for reproducibility +- Configuring CUDA settings for consistent precision +- Computing gradient accumulation parameters +- Handling gradient NaN/infinity and clipping +- Parsing model arguments +- Checking timing for periodic training tasks +""" + +import torch +import numpy as np +from omegaconf import ListConfig + + +def set_patch_shape(img_shape, patch_shape): + """ + Validate and adjust patch shape for image processing in diffusion models. 
+ + This function ensures that the patch shape is compatible with the image dimensions + and diffusion model requirements. It handles validation of patch dimensions and + applies constraints for model training. + + Parameters + ---------- + img_shape : tuple of int + Shape of the image as (height, width). + patch_shape : tuple of int or None + Desired patch shape as (height, width). Either dimension can be None to + indicate it should match the image dimension. + + Returns + ------- + tuple + A tuple of two tuples: + - validated_img_shape : tuple of int + The image shape as (height, width) + - validated_patch_shape : tuple of int + The adjusted patch shape as (height, width) + + Raises + ------ + NotImplementedError + If patch dimensions are not square (height != width) when patch is smaller + than image. Rectangular patches are not currently supported. + ValueError + If patch dimensions are not multiples of 32. This is a requirement for + compatibility with typical diffusion model architectures. + + Notes + ----- + The function applies the following logic: + 1. If any patch dimension is None or larger than the corresponding image dimension, + it is replaced with the image dimension + 2. If the patch is smaller than the image, it must be square + 3. All patch dimensions must be divisible by 32 for model compatibility + + Examples + -------- + >>> img_shape = (512, 512) + >>> patch_shape = (None, None) + >>> img_shape_out, patch_shape_out = set_patch_shape(img_shape, patch_shape) + >>> patch_shape_out + (512, 512) + + >>> img_shape = (512, 512) + >>> patch_shape = (256, 256) + >>> img_shape_out, patch_shape_out = set_patch_shape(img_shape, patch_shape) + >>> patch_shape_out + (256, 256) + + >>> img_shape = (512, 512) + >>> patch_shape = (256, 512) # Rectangular - will raise NotImplementedError + >>> set_patch_shape(img_shape, patch_shape) + """ + img_shape_y, img_shape_x = img_shape + patch_shape_y, patch_shape_x = patch_shape + if (patch_shape_x is None) or (patch_shape_x > img_shape_x): + patch_shape_x = img_shape_x + if (patch_shape_y is None) or (patch_shape_y > img_shape_y): + patch_shape_y = img_shape_y + if patch_shape_x != img_shape_x or patch_shape_y != img_shape_y: + if patch_shape_x != patch_shape_y: + raise NotImplementedError("Rectangular patch not supported yet") + if patch_shape_x % 32 != 0 or patch_shape_y % 32 != 0: + raise ValueError("Patch shape needs to be a multiple of 32") + return (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) + + +def set_seed(rank): + """ + Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings + """ + np.random.seed(rank % (1 << 31)) + torch.manual_seed(np.random.randint(1 << 31)) + + +def configure_cuda_for_consistent_precision(): + """ + Configures CUDA and cuDNN settings to ensure consistent precision by + disabling TensorFloat-32 (TF32) and reduced precision settings. + """ + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + +def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_size): + """ + Calculate the total batch size per GPU in a distributed setting, log the batch size per GPU, ensure it's within valid limits, + determine the number of accumulation rounds, and validate that the global batch size matches the expected value. 
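+
+    Example
+    -------
+    >>> # Hypothetical values: global batch of 8, at most 2 samples per GPU, 2 GPUs
+    >>> compute_num_accumulation_rounds(8, 2, 2)
+    (4, 2)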
+ """ + batch_gpu_total = total_batch_size // world_size + batch_size_per_gpu = batch_size_per_gpu + if batch_size_per_gpu is None or batch_size_per_gpu > batch_gpu_total: + batch_size_per_gpu = batch_gpu_total + num_accumulation_rounds = batch_gpu_total // batch_size_per_gpu + if total_batch_size != batch_size_per_gpu * num_accumulation_rounds * world_size: + raise ValueError( + "total_batch_size must be equal to batch_size_per_gpu * num_accumulation_rounds * world_size" + ) + return batch_gpu_total, num_accumulation_rounds + + +def handle_and_clip_gradients(model, grad_clip_threshold=None): + """ + Handles NaNs and infinities in the gradients and optionally clips the gradients. + + Parameters: + - model (torch.nn.Module): The model whose gradients need to be processed. + - grad_clip_threshold (float, optional): The threshold for gradient clipping. If None, no clipping is performed. + """ + # Replace NaNs and infinities in gradients + for param in model.parameters(): + if param.grad is not None: + torch.nan_to_num( + param.grad, nan=0.0, posinf=1e5, neginf=-1e5, out=param.grad + ) + + # Clip gradients if a threshold is provided + if grad_clip_threshold is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold) + + +def parse_model_args(args): + """Convert ListConfig values in args to tuples.""" + return {k: tuple(v) if isinstance(v, ListConfig) else v for k, v in args.items()} + + +def is_time_for_periodic_task( + cur_nimg, freq, done, batch_size, rank, rank_0_only=False +): + """Should we perform a task that is done every `freq` samples?""" + if rank_0_only and rank != 0: + return False + elif done: # Run periodic tasks also at the end of training + return True + else: + return cur_nimg % freq < batch_size + + +def is_time_for_periodic_task_epoch(epoch, freq, done, rank, rank_0_only=False): + """Should we perform a task that is done every `freq` samples?""" + if rank_0_only and rank != 0: + return False + elif done: # Run periodic tasks also at the end of training + return True + else: + return epoch % freq == 0 diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.pdf b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.pdf new file mode 100644 index 0000000000..1e0ed45099 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.pdf differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.png b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.png new file mode 100644 index 0000000000..0b621f0ccd Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses1.png differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.pdf b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.pdf new file mode 100644 index 0000000000..301c096502 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.pdf differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.png 
b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.png new file mode 100644 index 0000000000..ebf07274b3 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-Reynolds_stresses2.png differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.pdf b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.pdf new file mode 100644 index 0000000000..08ceffbf01 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.pdf differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.png b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.png new file mode 100644 index 0000000000..3d40942498 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-0.png differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.pdf b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.pdf new file mode 100644 index 0000000000..ef5484c8c4 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.pdf differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.png b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.png new file mode 100644 index 0000000000..ebfaa2a24b Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-jpdfs-1.png differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.pdf b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.pdf new file mode 100644 index 0000000000..370c7d66a8 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.pdf differ diff --git a/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.png b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.png new file mode 100644 index 0000000000..dd970c84e1 Binary files /dev/null and b/examples/weather/diffusion-urban-flows-2D/results/uncond_eval/epoch-1100/pred_snaps-1000-visual_comparison_num-0.png differ diff --git a/examples/weather/diffusion-urban-flows-2D/train.py b/examples/weather/diffusion-urban-flows-2D/train.py new file mode 100644 index 0000000000..4eb451c2f7 --- /dev/null +++ b/examples/weather/diffusion-urban-flows-2D/train.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training script for diffusion models on 2D flow fields using PhysicsNeMo."""
+
+import os
+import time
+import psutil
+import hydra
+import torch
+import tqdm
+
+from physicsnemo.models.diffusion import EDMPrecond
+from physicsnemo.distributed import DistributedManager
+from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
+from physicsnemo.metrics.diffusion import EDMLoss
+from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
+
+from omegaconf import DictConfig, OmegaConf
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.tensorboard import SummaryWriter
+
+from datasets.dataset import get_dataset_and_dataloader_from_config
+from helpers.train_helpers import (
+    set_seed,
+    configure_cuda_for_consistent_precision,
+    handle_and_clip_gradients,
+    is_time_for_periodic_task_epoch,
+)
+
+
+# Train the diffusion model using the configurations in "conf/config_training_uflow.yaml"
+@hydra.main(version_base="1.2", config_path="conf", config_name="config_training_uflow")
+def main(cfg: DictConfig) -> None:
+    """Train a diffusion model for 2D flow field generation.
+
+    This function initializes the distributed training environment, sets up
+    the model, optimizer, loss function, and data loaders, and trains the
+    diffusion model using the EDM (Elucidating the Design Space of
+    Diffusion-Based Generative Models) framework. Supports distributed data
+    parallel training across multiple GPUs.
+
+    Parameters
+    ----------
+    cfg : DictConfig
+        Hydra configuration containing model architecture, training
+        hyperparameters, dataset paths, and logging settings.
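+
+    Notes
+    -----
+    Typical invocations (illustrative; adjust Hydra overrides to your setup):
+    single GPU with ``python train.py``, multi-GPU with
+    ``torchrun --nproc_per_node=<num_gpus> train.py``. Any config entry can be
+    overridden on the command line, e.g. ``python train.py training.hp.epochs=1200``.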
+ """ + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize loggers + if dist.rank == 0: + writer = SummaryWriter(log_dir="tensorboard") + logger = PythonLogger("main") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger + + # Resolve and parse configs + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container( + cfg.dataset, resolve=True + ) # TODO needs better handling + del dataset_cfg[ + "dataset_features" + ] # Because we cannot pass dataset features into dataloader, can implement better later + + if hasattr(cfg, "validation"): + OmegaConf.to_container(cfg.validation) + else: + pass + + fp_optimizations = cfg.training.perf.fp_optimizations + fp16 = fp_optimizations == "fp16" # flag to use use fp16 + enable_amp = fp_optimizations.startswith("amp") # Flag for mixed precesion + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + + logger.info(f"Saving the outputs in {os.getcwd()}") + + checkpoint_dir = os.path.join( + cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" + ) + + # Set seeds and configure CUDA and cuDNN settings to ensure consistent precision + set_seed(dist.rank) + configure_cuda_for_consistent_precision() + + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.training.perf.dataloader_workers, + "prefetch_factor": 2, + } + + dataset, DataLoader = get_dataset_and_dataloader_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.training.hp.batch_size_per_gpu, + seed=dist.rank, + shuffle=True, + dist=dist, + Train=True, + ) + + dataset_channels = dataset.num_channels() + img_shape = dataset.image_shape() + + model = EDMPrecond( + img_resolution=list(img_shape), + img_channels=dataset_channels, + model_channels=cfg.model.model_args.model_channels, + channel_mult=cfg.model.model_args.channel_mult, + attn_resolutions=cfg.model.model_args.attn_resolutions, + use_fp16=fp16, + num_blocks=cfg.model.model_args.num_blocks, + dropout=cfg.model.model_args.dropout, + model_type="SongUNet", # TODO: check if dhariwalUnet can be used + channel_mult_emb=cfg.model.model_args.channel_mult_emb, + ) + + model.train().requires_grad_(True).to(dist.device) + + # Enable distributed data parallel if applicable + if dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=dist.find_unused_parameters, + ) + + loss_fn = EDMLoss() + + # Instantiate the optimizer + optimizer = torch.optim.Adam( + params=model.parameters(), lr=cfg.training.hp.lr, betas=[0.9, 0.999], eps=1e-8 + ) + + # Record the current time to measure the duration of subsequent operations. 
+    start_time = time.time()
+
+    # Resume training from a previous checkpoint if one exists
+    if dist.world_size > 1:
+        torch.distributed.barrier()
+    try:
+        epoch = load_checkpoint(
+            path=checkpoint_dir,
+            models=model,
+            optimizer=optimizer,
+            device=dist.device,
+        )
+    except Exception:
+        epoch = 0
+
+    if dist.rank == 0:
+        total_params = sum(p.numel() for p in model.parameters())
+        print(f"Number of parameters: {total_params / 1e6:.2f}M")
+
+    ############################################################################
+    #                           MAIN TRAINING LOOP                             #
+    ############################################################################
+
+    batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu
+    logger0.info(
+        f"Training for {cfg.training.hp.epochs} epochs... Starting from epoch {epoch}"
+    )
+    done = False
+
+    for ep in range(epoch, cfg.training.hp.epochs + 1):
+        tick_start_time = time.time()
+        loss_accum = 0
+        num_batches = len(dataloader)
+        pbar = tqdm.tqdm(enumerate(dataloader), total=len(dataloader), disable=True)
+
+        lr_decay = cfg.training.hp.lr_decay
+        decay_from = cfg.training.hp.lr_decay_from
+
+        for g in optimizer.param_groups:
+            # Apply learning rate decay after ramp-up
+            if ep >= decay_from:
+                g["lr"] = cfg.training.hp.lr * (
+                    lr_decay ** ((ep - decay_from) // decay_from)
+                )
+
+            current_lr = g["lr"]
+
+        for batch_index, batch in pbar:
+            # Compute & accumulate gradients
+            optimizer.zero_grad(set_to_none=True)
+            batch = batch.to(dist.device).to(torch.float32).contiguous()
+
+            with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
+                loss = loss_fn(
+                    net=model,
+                    images=batch,
+                    augment_pipe=None,
+                    labels=None,
+                )
+            loss = loss.sum() / batch_size_per_gpu
+            loss.backward()
+
+            # Clip gradients
+            handle_and_clip_gradients(
+                model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold
+            )
+
+            optimizer.step()
+            # Detach so the running epoch-mean loss does not retain the graph
+            loss_accum += loss.detach() / num_batches
+
+        # Average the per-rank epoch-mean loss across all ranks
+        loss_sum = torch.tensor([float(loss_accum)], device=dist.device)
+
+        if dist.world_size > 1:
+            torch.distributed.barrier()
+            torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
+        average_loss = (loss_sum / dist.world_size).cpu().item()
+
+        if dist.rank == 0:
+            writer.add_scalar("training_loss", average_loss, ep)
+
+        done = ep >= cfg.training.hp.epochs
+
+        if is_time_for_periodic_task_epoch(
+            ep,
+            cfg.training.io.print_progress_freq,
+            done,
+            dist.rank,
+            rank_0_only=True,
+        ):
+            batch_size = cfg.training.hp.batch_size_per_gpu
+            tick_end_time = time.time()
+            fields = []
+            fields += [f"epoch {ep:<6}"]  # Epoch-based progress tracking
+            fields += [f"avg_training_loss {average_loss:<7.2f}"]
+            fields += [f"batch_size {dist.world_size}x{batch_size}"]
+            fields += [f"learning_rate {current_lr:<7.8f}"]
+            fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"]
+            fields += [f"sec_per_epoch {(tick_end_time - tick_start_time):<7.1f}"]
+            fields += [
+                f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
+            ]
+            fields += [
+                f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}"
+            ]
+            fields += [
+                f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}"
+            ]
+            logger0.info(" ".join(fields))
+            torch.cuda.reset_peak_memory_stats()
+
+        # Convert serialized constructor args to plain containers so that
+        # checkpointing can serialize them; unwrap DDP when applicable
+        module = model.module if dist.world_size > 1 else model
+        original_args = OmegaConf.create(module._args)
+        module._args = OmegaConf.to_container(original_args, resolve=True)
+
+        # Save checkpoints
+        if dist.world_size > 1:
+            torch.distributed.barrier()
+        if is_time_for_periodic_task_epoch(
+            ep,
+            cfg.training.io.save_checkpoint_freq,
+            done,
+            dist.rank,
+            rank_0_only=True,
+        ):
+            save_checkpoint(
+                path=checkpoint_dir,
+                models=model,
+                optimizer=optimizer,
+                epoch=ep,
+            )
+
+    logger0.info("Training Completed.")
+
+
+if __name__ == "__main__":
+    main()
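As a companion to the training script above, the following is a minimal, illustrative EDM sampling sketch (not part of the diff; it assumes a trained `EDMPrecond` checkpoint loaded as `model`, and the `edm_heun_sampler` name and sample shape are hypothetical). It follows the deterministic second-order Heun discretization from Karras et al. (2022), the schedule that the EDM preconditioning is designed for:

```python
import torch


@torch.no_grad()
def edm_heun_sampler(
    net, shape, num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0, device="cuda"
):
    """Deterministic Heun (2nd-order) EDM sampler for an EDMPrecond-style denoiser."""
    # Karras et al. (2022) sigma schedule, from sigma_max down to sigma_min
    i = torch.arange(num_steps, device=device)
    t_steps = (
        sigma_max ** (1 / rho)
        + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat([t_steps, torch.zeros(1, device=device)])  # append t_N = 0

    x = torch.randn(shape, device=device) * t_steps[0]  # start from pure noise
    for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
        d_cur = (x - net(x, t_cur)) / t_cur  # Euler direction at t_cur
        x_next = x + (t_next - t_cur) * d_cur
        if t_next > 0:  # Heun correction (skipped on the final step)
            d_next = (x_next - net(x_next, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x


# Hypothetical usage with this example's field shape (2 channels, 288 x 96 grid):
# samples = edm_heun_sampler(model, shape=(16, 2, 288, 96))
```

The second-order correction improves sample quality per denoiser evaluation relative to a plain Euler discretization, which is why the EDM paper adopts it as the default sampler.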