# Generative Correction Diffusion Model (CorrDiff) for Km-scale Atmospheric Downscaling
CorrDiff is a cost-effective stochastic downscaling model that improves weather hazard prediction without expensive numerical simulations. It is trained on high-resolution weather data conditioned on coarser ERA5 reanalysis, and employs a two-step UNet-plus-diffusion approach to address multi-scale challenges. CorrDiff shows strong performance in predicting weather extremes and accurately captures multivariate relationships such as intense rainfall and typhoon dynamics, suggesting a promising future for global-to-km-scale machine learning weather forecasts.
To get started with CorrDiff, we provide a simplified version called CorrDiff-Mini that combines:
- A smaller neural network architecture that reduces memory usage and training time
- A reduced training dataset, based on the HRRR dataset, that contains fewer samples (available at NGC)
Together, these modifications reduce training time from thousands of GPU hours to around 10 hours on A100 GPUs. The simplified data loader included with CorrDiff-Mini also serves as a helpful example for training CorrDiff on custom datasets. Note that CorrDiff-Mini is intended for learning and educational purposes only - its predictions should not be used for real applications.
Start by installing PhysicsNeMo (if not already installed) and copying this folder (`examples/generative/corrdiff`) to a system with a GPU available. Also download the CorrDiff-Mini dataset from NGC.
CorrDiff training is managed through `train.py` and uses YAML configuration files powered by Hydra. The configuration system is organized as follows:
- Base Configurations: Located in the `conf/base` directory
- Configuration Files:
  - Training Configurations:
    - GEFS-HRRR dataset (continental United States):
      - `conf/config_training_gefs_hrrr_regression.yaml` - Configuration for training the regression model on the GEFS-HRRR dataset
      - `conf/config_training_gefs_hrrr_diffusion.yaml` - Configuration for training the diffusion model on the GEFS-HRRR dataset
    - HRRR-Mini dataset (smaller continental United States dataset):
      - `conf/config_training_hrrr_mini_regression.yaml` - Simplified regression model training setup for the HRRR-Mini example
      - `conf/config_training_hrrr_mini_diffusion.yaml` - Simplified diffusion model training setup for the HRRR-Mini example
    - Taiwan dataset:
      - `conf/config_training_taiwan_regression.yaml` - Configuration for training the regression model on Taiwan weather data
      - `conf/config_training_taiwan_diffusion.yaml` - Configuration for training the diffusion model on Taiwan weather data
    - Custom dataset:
      - `conf/config_training_custom.yaml` - Template configuration for training on custom datasets
  - Generation Configurations:
    - `conf/config_generate_taiwan.yaml` - Settings for generating predictions using Taiwan-trained models
    - `conf/config_generate_hrrr_mini.yaml` - Settings for generating predictions using HRRR-Mini models
    - `conf/config_generate_gefs_hrrr.yaml` - Settings for generating predictions using GEFS-HRRR models
    - `conf/config_generate_custom.yaml` - Template configuration for generation with custom trained models
To select a specific configuration, use the `--config-name` option when running the training script. Each training configuration file defines three main components:
- Training dataset parameters
- Model architecture settings
- Training hyperparameters
You can modify configuration options in two ways:
- Direct Editing: Modify the YAML files directly
- Command Line Override: Use Hydra's `++` syntax to override settings at runtime
For example, to change the training batch size (controlled by `training.hp.total_batch_size`):

```bash
python train.py ++training.hp.total_batch_size=64  # Sets batch size to 64
```
This modular configuration system allows for flexible experimentation while maintaining reproducibility.
CorrDiff uses a two-step training process:
- Train a deterministic regression model
- Train a diffusion model using the pre-trained regression model
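The two steps combine at inference time: the regression model supplies a deterministic mean prediction, and the diffusion model adds a stochastic residual on top of it. The NumPy sketch below illustrates only this decomposition; `regression_mean` and `sample_residual` are crude placeholders, not the actual CorrDiff networks.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder for the regression UNet: a deterministic "mean" prediction
# from the low-resolution input (here, naive 2x nearest-neighbor upsampling).
def regression_mean(x_lr):
    return x_lr.repeat(2, axis=-1).repeat(2, axis=-2)

# Placeholder for the diffusion model: a stochastic residual sample.
def sample_residual(x_lr, shape):
    return 0.1 * rng.standard_normal(shape)

x_lr = rng.standard_normal((1, 8, 8))                  # [channels, height, width]
mean = regression_mean(x_lr)                           # step 1: deterministic mean
prediction = mean + sample_residual(x_lr, mean.shape)  # step 2: add stochastic residual

print(prediction.shape)  # (1, 16, 16)
```

Drawing several residual samples for the same input yields an ensemble of plausible high-resolution fields around the regression mean.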
For the CorrDiff-Mini regression model, we use the following configuration components:
The top-level configuration file `config_training_hrrr_mini_regression.yaml` contains the most commonly modified parameters:

- `dataset`: Dataset type and paths (`hrrr_mini`, `gefs_hrrr`, `cwb`, or `custom`)
- `model`: Model architecture type (`regression`, `diffusion`, etc.)
- `model_size`: Model capacity (`normal`, or `mini` for faster experiments)
- `training`: High-level training parameters (duration, batch size, IO settings)
- `wandb`: Weights & Biases logging settings (`mode`, `results_dir`, `watch_model`)

This configuration automatically loads these specific files from `conf/base`:

- `dataset/hrrr_mini.yaml`: HRRR-Mini dataset parameters (data paths, variables)
- `model/regression.yaml`: Regression UNet architecture settings
- `model_size/mini.yaml`: Reduced model capacity settings for faster training
- `training/regression.yaml`: Training loop parameters specific to the regression model
These base configuration files contain more detailed settings that are less commonly modified but give fine-grained control over the training process.
To begin training, execute the following command using `train.py`:

```bash
python train.py --config-name=config_training_hrrr_mini_regression.yaml
```
Training Details:
- Duration: A few hours on a single A100 GPU
- Checkpointing: Automatically resumes from latest checkpoint if interrupted
- Multi-GPU Support: Compatible with `torchrun` or MPI for distributed training
💡 Memory Management
The default configuration uses a batch size of 256 (controlled by `training.hp.total_batch_size`). If you encounter memory constraints, particularly on GPUs with limited memory, you can reduce the per-GPU batch size by setting `++training.hp.batch_size_per_gpu=64`. CorrDiff will automatically employ gradient accumulation to maintain the desired effective batch size while using less memory.
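The gradient-accumulation pattern that this relies on can be sketched in plain PyTorch (a toy model, not CorrDiff code; the variable names mirror the configuration keys):

```python
import torch

# Toy model and optimizer, purely to illustrate the accumulation pattern.
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

total_batch_size = 256     # training.hp.total_batch_size
batch_size_per_gpu = 64    # ++training.hp.batch_size_per_gpu=64
num_gpus = 1
accum_steps = total_batch_size // (batch_size_per_gpu * num_gpus)  # 4 micro-batches

opt.zero_grad()
for step in range(accum_steps):
    x = torch.randn(batch_size_per_gpu, 4)
    y = torch.randn(batch_size_per_gpu, 1)
    # Scale each micro-batch loss so the summed gradient matches one big batch.
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()        # gradients accumulate in .grad across micro-batches
opt.step()                 # one update with the effective batch size of 256
```

Only the memory footprint changes; the optimizer still sees one update per effective batch.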
After successfully training the regression model, you can proceed with training the diffusion model. The process requires:
- A pre-trained regression model checkpoint
- The same dataset used for regression training
- Configuration file `conf/config_training_hrrr_mini_diffusion.yaml`
To start the diffusion model training, execute:
```bash
python train.py --config-name=config_training_hrrr_mini_diffusion.yaml \
    ++training.io.regression_checkpoint_path=</path/to/regression/model>
```
where `</path/to/regression/model>` should point to the saved regression checkpoint.
The training will generate checkpoints in the `checkpoints_diffusion` directory. Upon completion, the final model will be saved as `EDMPrecondSR.0.8000000.mdlus`.
Once both models are trained, you can use `generate.py` to create new predictions. The generation process requires:
Required Files:
- Trained regression model checkpoint
- Trained diffusion model checkpoint
- Configuration file `conf/config_generate_hrrr_mini.yaml`
Execute the generation command:
```bash
python generate.py --config-name="config_generate_hrrr_mini.yaml" \
    ++generation.io.res_ckpt_filename=</path/to/diffusion/model> \
    ++generation.io.reg_ckpt_filename=</path/to/regression/model>
```
The output is saved as a NetCDF4 file containing three groups:
- `input`: The original input data
- `truth`: The ground truth data for comparison
- `prediction`: The CorrDiff model predictions
You can analyze the results using the Python NetCDF4 library or visualization tools of your choice.
The Taiwan example demonstrates CorrDiff training on a high-resolution weather dataset conditioned on the low-resolution ERA5 dataset. This dataset is available for non-commercial use under the CC BY-NC-ND 4.0 license.
Dataset Access:
- Location: NGC Catalog - CWA Dataset
- Download Command: `ngc registry resource download-version "nvidia/modulus/modulus_datasets_cwa:v1"`
The Taiwan example supports three types of models, each serving a different purpose:
- Regression Model: Basic deterministic model
- Diffusion Model: Full probabilistic model
- Patch-based Diffusion Model: Memory-efficient variant that processes small spatial regions to improve scalability
The patch-based approach divides the target region into smaller subsets during both training and generation, making it particularly useful for memory-constrained environments or large spatial domains.
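The core idea of dividing the domain into patches can be sketched with a random-crop helper (`random_patch` is a hypothetical illustration, not the CorrDiff code path):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_patch(field, patch_shape_y, patch_shape_x):
    """Crop one random patch from a [channels, height, width] field.

    Illustrates the idea behind patch-based training: each crop becomes
    a training sample, bounding memory use regardless of domain size.
    """
    _, h, w = field.shape
    y0 = rng.integers(0, h - patch_shape_y + 1)
    x0 = rng.integers(0, w - patch_shape_x + 1)
    return field[:, y0:y0 + patch_shape_y, x0:x0 + patch_shape_x]

field = rng.standard_normal((3, 448, 448))  # full high-resolution target region
patch = random_patch(field, 64, 64)         # one memory-bounded training sample
print(patch.shape)  # (3, 64, 64)
```

At generation time the patches are stitched back together to cover the full domain, which is why patch size and overlap matter (see the patch size note further below).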
Configuration Structure:
The top-level configuration file `config_training_taiwan_regression.yaml` contains commonly modified parameters:

- `dataset`: Set to `cwb` for the Taiwan Central Weather Bureau dataset
- `model`: Model type (`regression`, `diffusion`, or `patched_diffusion`)
- `model_size`: Model capacity (`normal` recommended for the Taiwan dataset)
- `training.hp`: Training duration and batch size settings
- `wandb`: Experiment tracking configuration

This configuration automatically loads these specific files from `conf/base`:

- `dataset/cwb.yaml`: Taiwan dataset parameters
- `model/regression.yaml` or `model/diffusion.yaml`: Model architecture settings
- `training/regression.yaml` or `training/diffusion.yaml`: Training parameters
When training the diffusion variants, you'll need to specify the path to your pre-trained regression checkpoint in `training.io.regression_checkpoint_path`. This is essential, as the diffusion model learns to predict residuals on top of the regression model's predictions.
Training Commands:
For single-GPU training:

```bash
python train.py --config-name=config_training_taiwan_regression.yaml
```

For multi-GPU or multi-node training:

```bash
torchrun --standalone --nnodes=<NUM_NODES> --nproc_per_node=<NUM_GPUS_PER_NODE> train.py
```
To switch between model types, simply change the configuration name in the training command (e.g., `config_training_taiwan_diffusion.yaml` for the diffusion model).
The evaluation pipeline for CorrDiff models consists of two main components:
1. Sample Generation (`generate.py`): Generates predictions and saves them in NetCDF file format. The process uses configuration settings from `conf/config_generate_taiwan.yaml`.

   ```bash
   python generate.py --config-name=config_generate_taiwan.yaml
   ```

2. Performance Scoring (`score_samples.py`): Computes both deterministic metrics (like MSE and MAE) and probabilistic scores for the generated samples.

   ```bash
   python score_samples.py path=<PATH_TO_NC_FILE> output=<OUTPUT_FILE>
   ```
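The exact metrics implemented in `score_samples.py` are not reproduced here, but as an illustration of what a probabilistic score for ensemble predictions looks like, here is the standard sample-based CRPS estimator:

```python
import numpy as np

def crps_ensemble(ensemble, obs):
    """Sample-based CRPS estimate: E|X - y| - 0.5 * E|X - X'|.

    A generic illustration of a probabilistic verification score for an
    ensemble of predictions against one observation (not the actual
    score_samples.py implementation). Lower is better; 0 is perfect.
    """
    ensemble = np.asarray(ensemble, dtype=float)
    term1 = np.mean(np.abs(ensemble - obs))
    term2 = 0.5 * np.mean(np.abs(ensemble[:, None] - ensemble[None, :]))
    return term1 - term2

rng = np.random.default_rng(0)
obs = 1.0
ensemble = obs + 0.2 * rng.standard_normal(16)  # 16 ensemble members
print(crps_ensemble(ensemble, obs))
```

Unlike MSE on the ensemble mean, CRPS rewards ensembles whose spread matches the actual uncertainty, which is the point of a stochastic downscaler like CorrDiff.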
For visualization and analysis, you have several options:
- Use the plotting scripts in the `inference` directory
- Visualize results with Earth2Studio
- Create custom visualizations using the NetCDF4 output structure
CorrDiff supports two powerful tools for experiment tracking and visualization:
TensorBoard Integration: TensorBoard provides real-time monitoring of training metrics when running in a Docker container:
1. Configure Docker (include port forwarding):

   ```bash
   docker run -p 6006:6006 ...
   ```

2. Start TensorBoard:

   ```bash
   tensorboard --logdir=/path/to/logdir --port=6006
   ```

3. Set up an SSH tunnel (for remote servers):

   ```bash
   ssh -L 6006:localhost:6006 <user>@<remote-server-ip>
   ```

4. Access the dashboard at `http://localhost:6006`
Weights & Biases Integration: CorrDiff includes integration with Weights & Biases for experiment tracking. The following parameters are hardcoded in the code:
- Project name: "Modulus-Launch"
- Entity: "Modulus"
- Run name: Generated based on configuration job name
- Group: "CorrDiff-DDP-Group"
You can configure the following wandb parameters in the configuration files:
```yaml
wandb:
    mode: offline           # Options: "online", "offline", "disabled"
    results_dir: "./wandb"  # Directory to store wandb results
    watch_model: true       # Whether to track model parameters and gradients
```
To use wandb:
1. Initialize wandb (first time only):

   ```bash
   wandb login
   ```

2. Training runs will automatically log to the wandb project, tracking:
   - Training and validation metrics
   - Model architecture details
   - System resource usage
   - Hyperparameters
You can access your experiment dashboard on the Weights & Biases website.
This repository includes examples of CorrDiff training on specific datasets, such as Taiwan and HRRR. However, many use cases require training CorrDiff on a custom high-resolution dataset. The steps below outline the process.
To train CorrDiff on a custom dataset, you need to implement a custom dataset class that inherits from `DownscalingDataset`, defined in `datasets/base.py`. This base class defines the interface that all dataset implementations must follow.
Required Implementation:
Your dataset class must inherit from `DownscalingDataset` and implement its abstract methods, for example:

- `longitude()` and `latitude()`: Return coordinate arrays
- `input_channels()` and `output_channels()`: Define metadata for input/output variables
- `time()`: Return time values
- `image_shape()`: Return spatial dimensions
- `__len__()`: Return total number of samples
- `__getitem__()`: Return data for a given index
The most important method is `__getitem__`, which must return a tuple of tensors:

```python
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Returns:
        Tuple containing:
        - img_clean: Target high-resolution data [output_channels, height, width]
        - img_lr: Input low-resolution data [input_channels, height, width]
        - lead_time_label: (Optional) Lead time information [1]
    """
    # Your implementation here
    # For a basic implementation without lead time:
    return img_clean, img_lr
    # If including lead time information:
    # return img_clean, img_lr, lead_time_label
```
Configure your dataset in the YAML configuration file. Any parameters other than `type` will be passed to your dataset's `__init__` method. For example:

```yaml
dataset:
    type: path/to/your/dataset.py::CustomDataset  # Path to file::class name
    # All parameters below will be passed to your dataset's __init__
    data_path: /path/to/your/data
    stats_path: /path/to/statistics.json  # Optional normalization stats
    input_variables: ["temperature", "pressure"]  # Example parameters
    output_variables: ["high_res_temperature"]
    invariant_variables: ["topography"]  # Optional static fields
    # Add any other parameters needed by your dataset class
```
Important Notes:
- The training script will automatically:
  - Parse the `type` field to locate your dataset file and class
  - Register your custom dataset class using `register_dataset()`
  - Pass all other fields in the `dataset` section as kwargs to your class constructor
- All tensors should be properly normalized (use the `normalize_input`/`normalize_output` methods if needed)
- Ensure consistent dimensions across all samples
- Channel metadata should accurately describe your data variables
For reference implementations of dataset classes, look at:
datasets/hrrrmini.py
- Simple example using NetCDF formatdatasets/cwb.py
- More complex example
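The interface above can also be sketched in a minimal, self-contained class. Random tensors stand in for real data here, and in a real implementation the class would inherit from `DownscalingDataset` and load actual files; all shapes, coordinates, and variable names below are illustrative.

```python
import numpy as np
import torch

class CustomDataset:
    """Minimal sketch of the required dataset interface (illustrative only)."""

    def __init__(self, data_path=None, n_samples=4):
        self.n_samples = n_samples
        # Hypothetical coordinate grids; a real dataset reads these from files.
        self._lon = np.linspace(120.0, 122.0, 32)
        self._lat = np.linspace(21.5, 25.5, 32)

    def longitude(self):
        return self._lon

    def latitude(self):
        return self._lat

    def input_channels(self):
        return ["temperature", "pressure"]   # metadata for low-res input variables

    def output_channels(self):
        return ["high_res_temperature"]      # metadata for high-res target variables

    def time(self):
        return np.arange(self.n_samples)

    def image_shape(self):
        return (32, 32)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        img_clean = torch.randn(1, 32, 32)   # [output_channels, height, width]
        img_lr = torch.randn(2, 32, 32)      # [input_channels, height, width]
        return img_clean, img_lr

ds = CustomDataset()
clean, lr = ds[0]
print(len(ds), clean.shape, lr.shape)
```

Note that the channel counts returned by `__getitem__` must stay consistent with `input_channels()` and `output_channels()`, since the training script sizes the networks from that metadata.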
After implementing your custom dataset, you can proceed with the two-step training process followed by generation. The configuration system uses a hierarchical structure that balances ease of use with detailed control over the training process.
Top-level Configuration (`config_training_custom.yaml`):

This file serves as your primary interface for configuring the training process. It contains commonly modified parameters that can be set either directly in the file or through command-line overrides:

- `dataset`: Configuration for your custom dataset implementation, including paths and variables
- `model`: Core model settings, including type selection (`regression` or `diffusion`)
- `training`: High-level training parameters like batch size and duration
- `wandb`: Weights & Biases logging settings (`mode`, `results_dir`, `watch_model`)
Fine-grained Control:
The base configuration files in `conf/base/` provide detailed control over specific components. These files are automatically loaded based on your top-level choices:

- `model/*.yaml`: Contains architecture-specific settings for network depth, attention mechanisms, and embedding configurations
- `training/*.yaml`: Defines training loop behavior, including optimizer settings and checkpoint frequency
- `model_size/*.yaml`: Provides preset configurations for different model capacities
While direct modification of these base files is typically unnecessary, any parameter can be overridden using Hydra's `++` syntax. For example, to reduce the learning rate to 0.0001:

```bash
python train.py --config-name=config_training_custom.yaml ++training.hp.lr=0.0001
```
This configuration system allows you to start with sensible defaults while maintaining the flexibility to customize any aspect of the training process.
You can directly modify the training configuration file to change the dataset, model, and training parameters, or use Hydra's `++` syntax to override them. Once the regression model is trained, proceed with training the diffusion model. During training, you can fine-tune various parameters. The most commonly adjusted ones include:

- `training.hp.total_batch_size`: Controls the total batch size across all GPUs
- `training.hp.batch_size_per_gpu`: Adjusts per-GPU memory usage
- `training.hp.patch_shape_x/y`: Sets dimensions for patch-based training
- `training.hp.training_duration`: Defines total training steps
- `training.hp.lr_rampup`: Controls the learning rate warmup period
Starting with a Small Model

When developing a new dataset implementation, it is recommended to start with a smaller model for faster iteration and debugging. You can do this by setting `model_size: mini` in your configuration file:

```yaml
defaults:
    - model_size: mini  # Use smaller architecture for testing
```

This is similar to the model used in the HRRR-Mini example and can significantly reduce testing time. After debugging, you can switch back to the full model by setting `model_size` to `normal`.
Note on Patch Size Selection

When implementing patch-based training, choosing the right patch size is critical for model performance. The patch dimensions are controlled by `patch_shape_x` and `patch_shape_y` in your configuration file. To determine optimal patch sizes:

- Calculate the auto-correlation function of your data using the provided utilities in `inference/power_spectra.py`: `average_power_spectrum()` and `power_spectra_to_acf()`
- Set patch dimensions to match or exceed the distance at which the auto-correlation approaches zero
- This ensures each patch captures the full spatial correlation structure of your data
This analysis helps balance computational efficiency with the preservation of important physical relationships in your data.
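The signatures of `average_power_spectrum()` and `power_spectra_to_acf()` are not reproduced here, but the analysis they perform can be sketched generically with NumPy via the Wiener-Khinchin relation (power spectrum → auto-correlation). Everything below, including the synthetic field and the 0.05 correlation threshold, is illustrative.

```python
import numpy as np

def acf_1d(field):
    """Estimate a 1-D auto-correlation curve from a 2-D field via its power
    spectrum (Wiener-Khinchin). Generic sketch, not the power_spectra.py API."""
    f = field - field.mean()
    power = np.abs(np.fft.fft2(f)) ** 2          # 2-D power spectrum
    acf2d = np.real(np.fft.ifft2(power))         # inverse FFT -> auto-correlation
    acf2d /= acf2d[0, 0]                         # normalize so acf(0) = 1
    return acf2d[0, : field.shape[1] // 2]       # correlation along one axis

rng = np.random.default_rng(0)
# Smooth white noise to give it ~8-pixel correlations (synthetic stand-in data).
noise = rng.standard_normal((256, 256))
kernel = np.exp(-np.arange(-16, 17) ** 2 / (2 * 8.0 ** 2))
smooth = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode="same"), 1, noise)

acf = acf_1d(smooth)
# Smallest safe patch size: first lag where correlation drops near zero.
decorrelation = int(np.argmax(acf < 0.05))
print(decorrelation)  # patch_shape_x/y should be at least this many pixels
```

On real data you would run this on representative fields of each output variable and take the largest decorrelation distance as the lower bound on the patch size.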
After training both models successfully, you can use CorrDiff's generation pipeline to create predictions. The generation system follows a similar hierarchical configuration structure as training.
Top-level Configuration (`config_generate_custom.yaml`):

This file serves as the main interface for controlling the generation process. It defines essential parameters that can be modified either directly or through command-line overrides.

Fine-grained Control:

The base configuration files in `conf/base/generation` provide fine-grained control over the generation process. For example, `sampling/stochastic.yaml` controls the stochastic sampling process (noise scheduling, number of sampling steps, classifier-free guidance settings). While these base configurations are typically used as-is, you can override any parameter directly in the configuration file or using Hydra's `++` syntax. For example, to increase the number of ensemble members generated per input, you can run:

```bash
python generate.py --config-name=config_generate_custom.yaml \
    ++generation.io.res_ckpt_filename=/path/to/diffusion/checkpoint.mdlus \
    ++generation.io.reg_ckpt_filename=/path/to/regression/checkpoint.mdlus \
    ++dataset.type=path/to/your/dataset.py::CustomDataset \
    ++generation.num_ensembles=10
```
Key generation parameters that can be adjusted include:

- `generation.num_ensembles`: Number of samples to generate per input
- `generation.patch_shape_x/y`: Patch dimensions for patch-based generation
The generated samples are saved in a NetCDF file with three main components:
- Input data: The original low-resolution inputs
- Ground truth: The actual high-resolution data (if available)
- Predictions: The model-generated high-resolution outputs
- Are there pre-trained checkpoints available, and when should they be used for training/inference?

  Pre-trained checkpoints are available through NVIDIA AI Enterprise. For example, a trained model for the continental United States on the GEFS-HRRR dataset is available here. However, note that these checkpoints are not necessarily compatible with the current implementation of `train.py` and `generate.py` in CorrDiff. Typically, these checkpoints should only be used for inference in Earth2Studio. It is therefore generally recommended to train CorrDiff models from scratch. If you do have a checkpoint compatible with the current implementation of `train.py` and `generate.py` (e.g., from one of your own previous training runs), it is recommended to restart training from your checkpoint rather than from scratch if the following conditions are met:

  - Your custom dataset covers a region included in the training data of the checkpoint (e.g., a sub-region of the continental United States for the checkpoint mentioned above).
  - At most half of the variables in your dataset are also included in the training data of the checkpoint.

  Training from scratch is recommended for all other cases.
- How many samples are needed to train a CorrDiff model?

  The more, the better. As a rule of thumb, at least 50,000 samples are necessary. Note: for patch-based diffusion, each patch can be counted as a sample.
How many GPUs are required to train CorrDiff?
A single GPU is sufficient as long as memory is not exhausted, but this may result in extremely slow training. To accelerate training, CorrDiff leverages distributed data parallelism. The total training wall-clock time roughly decreases linearly with the number of GPUs. Most CorrDiff training examples have been conducted with 64 A100 GPUs. If you encounter an out-of-memory error, reducebatch_size_per_gpu
or, for patch-based diffusion models, decrease the patch size—ensuring it remains larger than the auto-correlation distance. -
How long does it take to train CorrDiff on a custom dataset?
Training CorrDiff on the continental United States dataset required approximately 5,000 A100 GPU hours. This corresponds to roughly 80 hours of wall-clock time with 64 GPUs. You can expect the cost to scale linearly with the number of samples available. -
What are CorrDiff's current limitations for custom datasets?
The main limitation of CorrDiff is the maximum downscaling ratio it can achieve. For a purely spatial super-resolution task (where input and output variables are the same), CorrDiff can reliably achieve a maximum resolution scaling of ×16. If the task involves inferring new output variables, the maximum reliable spatial super-resolution is ×11. -
What does a successful training look like?
In a successful training run, the loss function should decrease monotonically, as shown below:
- Which hyperparameters are most important?

  One of the most crucial hyperparameters is the patch size for a patch-based diffusion model (`patch_shape_x` and `patch_shape_y` in the configuration file). A larger patch size increases computational cost and GPU memory requirements, while a smaller patch size may lead to a loss of physical information. The patch size should not be smaller than the auto-correlation distance, which can be determined using the auto-correlation plotting utility. Other important hyperparameters include:

  - Training duration (`training.hp.training_duration`): Total number of samples to process during training. Values between 1M and 30M samples are typical, depending on the size of the dataset and on the type of model (regression or diffusion).
  - Learning rate ramp-up (`training.hp.lr_rampup`): Number of samples over which the learning rate gradually increases. In some cases, `lr_rampup=0` is sufficient, but if training is unstable, it may be necessary to increase it. Values between 0 and 200M samples are typical.
  - Learning rate (`training.hp.lr`): Base learning rate that controls how quickly model parameters are updated. It may be decreased if training is unstable, and increased if training is slow.
  - Batch size per GPU (`training.hp.batch_size_per_gpu`): Number of samples processed in parallel on each GPU. It needs to be reduced if you encounter an out-of-memory error.