Various CorrDiff optimizations to drastically increase training efficiency #809

Open · wants to merge 10 commits into `main`
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- DrivAerML dataset support in FIGConvNet example.
- Retraining recipe for DoMINO from a pretrained model checkpoint
- Added Datacenter CFD use case.
- Added `ResLoss_Opt` for patch amortized CorrDiff training

### Changed

@@ -25,6 +26,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Updated utils in `modulus.launch.logging` to avoid unnecessary `wandb` and `mlflow` imports
- Moved to experiment-based Hydra config in Lagrangian-MGN example
- Make data caching optional in `MeshDatapipe`
- Updated CorrDiff training code to support multiple patch iterations to amortize regression cost and the use of `torch.compile` (see the sketch after this diff)
- Refactored `modulus/models/diffusion/layers.py` to optimize data type casting workflow,
avoiding unnecessary casting under autocast mode
- Refactored Conv2d to enable fusion of conv2d with bias addition
- Refactored GroupNorm, UNetBlock, SongUNet, and SongUNetPosEmbd to support Apex GroupNorm, fusion of the activation with GroupNorm, and the AMP workflow
- Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd`
- Updated `from_checkpoint` to accommodate usage of Apex GroupNorm
- Refactored CorrDiff NVTX annotation workflow to be configurable

### Deprecated

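The two changelog entries about `ResLoss_Opt` and multiple patch iterations describe the core optimization: the regression forward pass depends only on the low-resolution input, so it can be computed once per batch and reused while the diffusion loss is evaluated over several groups of patches. A minimal sketch of that pattern, assuming hypothetical names (`regression_net`, `diffusion_loss_fn`, `patch_nums_iter`) rather than the PR's actual API:

```python
import torch

def training_step(regression_net, diffusion_net, diffusion_loss_fn,
                  img_clean, img_lr, patch_nums_iter, optimizer):
    """Amortize one regression forward pass over several patch iterations."""
    optimizer.zero_grad(set_to_none=True)

    # The regression (mean) prediction depends only on the full low-res
    # input, so compute it once per batch, outside the patch loop.
    with torch.no_grad():
        mean_hr = regression_net(img_lr)

    for num_patches in patch_nums_iter:
        # Each iteration draws `num_patches` patches per sample and
        # conditions the diffusion model on the matching crops of `mean_hr`.
        loss = diffusion_loss_fn(
            net=diffusion_net,
            img_clean=img_clean,
            img_lr=img_lr,
            mean_hr=mean_hr,
            num_patches=num_patches,
        )
        # Accumulate gradients across patch iterations; one optimizer step
        # per batch keeps the effective batch size unchanged.
        (loss / len(patch_nums_iter)).backward()

    optimizer.step()
```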
13 changes: 8 additions & 5 deletions examples/generative/corrdiff/conf/config_training.yaml
@@ -25,17 +25,20 @@ hydra:
 defaults:

   # Dataset
-  - dataset/cwb_train
+  # - dataset/cwb_train
+  - dataset/hrrr_corrdiff_synthetic

   # Model
-  - model/corrdiff_regression
+  # - model/corrdiff_regression
   #- model/corrdiff_diffusion
-  #- model/corrdiff_patched_diffusion
+  - model/corrdiff_patched_diffusion

   # Training
-  - training/corrdiff_regression
+  # - training/corrdiff_regression
   #- training/corrdiff_diffusion
-  #- training/corrdiff_patched_diffusion
+  # - training/corrdiff_patched_diffusion
+  - training/corrdiff_patched_diffusion_opt

   # Validation (comment out to disable validation)
   - validation/cwb
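Note that with Hydra's defaults list, the same selection can be made from the command line without editing this file (e.g. `python train.py model=corrdiff_patched_diffusion training=corrdiff_patched_diffusion_opt`, assuming the example's usual `train.py` entry point); the change above simply makes the optimized patched-diffusion recipe the default.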
98 changes: 98 additions & 0 deletions examples/generative/corrdiff/conf/training/corrdiff_patched_diffusion_opt.yaml (new file)
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hyperparameters
hp:

  training_duration: 200000000
  # Training duration
  total_batch_size: 16
  # Total batch size
  batch_size_per_gpu: 4
  # Limit batch size per GPU
  cbase: null  # TODO check
  # Channel multiplier
  cres: 1  # TODO check
  # Channels per resolution

  ema: 0.5
  # EMA half-life
  dropout: 0.13
  # Dropout probability
  augment: 0.0
  # Augment probability
  # hr_mean_conditioning: False
  # High-res mean (regression's output) as additional condition
  # gridtype: "sinusoidal"
  # can be either linear, sinusoidal, or learnable
  # N_grid_channels: 4

  lr: 0.0002
  # Learning rate
  grad_clip_threshold: 1e5
  # Gradient clipping threshold
  lr_decay: 0.7
  # Learning rate decay factor
  lr_rampup: 10000
  # Learning rate ramp-up duration


  patch_shape_x: 448
  patch_shape_y: 448
  # Patch size. Patch-based training is used if these dimensions differ from
  # img_shape_x and img_shape_y
  patch_num: 16
  # Number of patches extracted from a single sample. The total number of patches
  # is patch_num * total_batch_size
  max_patch_per_gpu: 9
  # Maximum number of patches processed at once on a single GPU
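  # Note (assumed relationship, not stated in this file): with
  # batch_size_per_gpu: 4 and max_patch_per_gpu: 9, at most 9 // 4 = 2
  # patches fit per iteration, so the 16 patches per sample would be
  # processed in ceil(16 / 2) = 8 patch iterations, each reusing the
  # regression output computed once for the batch.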
  hr_mean_conditioning: True
  # High-res mean (regression's output) as additional condition
  gridtype: "learnable"
  # can be either linear, sinusoidal, or learnable
  N_grid_channels: 100

  P_mean: -1.2
  P_std: 1.2
  sigma_data: 0.5


# Performance
perf:
  fp_optimizations: amp-bf16
  # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
  # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
  dataloader_workers: 4
  # DataLoader worker processes
  songunet_checkpoint_level: 0  # 0 means no checkpointing
  # Gradient checkpointing level; the value is the number of layers to checkpoint
  # optimization_mode: True
  use_apex_gn: True
  # Use fused Apex GroupNorm kernels instead of torch.nn.GroupNorm
  torch_compile: True
  # Compile the model with torch.compile
  profile_mode: False
  # Enable the configurable NVTX profiling annotations


# I/O
io:
  regression_checkpoint_path: /lustre/fsw/portfolios/coreai/users/asui/video-corrdiff-checkpoints/training-state-regression-000513.mdlus
  # Where to load the regression checkpoint
  print_progress_freq: 1000
  # How often to print progress
  save_checkpoint_freq: 500000
  # How often to save the checkpoints, measured in number of processed samples
  validation_freq: 5000
  # How often to record the validation loss, measured in number of processed samples
  validation_steps: 10
  # How many loss evaluations are used to compute the validation loss per checkpoint
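For orientation, here is a rough sketch of how the `perf` options above could be consumed by a training script. Only the config keys (`fp_optimizations`, `torch_compile`, `use_apex_gn`) come from this file; the function and wiring are illustrative assumptions, not the PR's actual code:

```python
import torch

def apply_perf_options(model, perf_cfg):
    """Wire precision and compilation flags from the `perf` config block."""
    # "amp-*" modes select an autocast dtype; "fp32" and "fp16" run
    # without autocast (the dict lookup returns None for them).
    amp_dtype = {
        "amp-fp16": torch.float16,
        "amp-bf16": torch.bfloat16,
    }.get(perf_cfg["fp_optimizations"])

    # use_apex_gn would be consumed at model-construction time (fused
    # Apex GroupNorm kernels), so it is not wired up here.

    if perf_cfg.get("torch_compile", False):
        model = torch.compile(model)

    def forward(*args, **kwargs):
        if amp_dtype is not None:
            with torch.autocast("cuda", dtype=amp_dtype):
                return model(*args, **kwargs)
        return model(*args, **kwargs)

    return forward
```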