Skip to content

Cordiff usability and performance enhancements for custom dataset training #790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
eadd8f5
Add recent checkpoints option, adjust configs
pzharrington Jan 28, 2025
1ae9e7f
Doc for deterministic_sampler
CharlelieLrt Feb 4, 2025
d61aa08
Typo fix
CharlelieLrt Feb 5, 2025
934c1f3
Bugfix and cleanup of corrdiff regression loss and UNet
CharlelieLrt Feb 6, 2025
f120055
Minor fix in docstrings
CharlelieLrt Feb 6, 2025
a7f0836
Bugfix + doc for corrdiff regression CE loss
CharlelieLrt Feb 6, 2025
984adae
Refactor corrdiff configs for custom dataset
CharlelieLrt Feb 8, 2025
11207f7
Bugfix in configs
CharlelieLrt Feb 10, 2025
344ab6c
Added info in corrdiff docs for custom training
CharlelieLrt Feb 11, 2025
a0c59b0
Minor change in corrdiff config
CharlelieLrt Feb 11, 2025
c244e53
bring back base config file removed by mistake
CharlelieLrt Feb 11, 2025
b6a7c2d
Added config for generation on custom dataset
CharlelieLrt Feb 12, 2025
a6c40e1
Forgot some config files
CharlelieLrt Feb 12, 2025
62e6e50
Fixed overlap pixel in custom config based on discussion in PR #703
CharlelieLrt Feb 12, 2025
c1d082c
Corrdiff fixes to enable non-squared images and/or non-square patches…
CharlelieLrt Feb 12, 2025
f8a1c17
Fix small bug in config
CharlelieLrt Feb 12, 2025
d7588ac
Removed arguments redundancy in patching utilities + fixed hight-widt…
CharlelieLrt Feb 13, 2025
3d30e2a
Cleanup
CharlelieLrt Feb 14, 2025
47a054d
Added tests for rectangle images and patches
CharlelieLrt Feb 14, 2025
ddd2f4d
Added wandb logging for corrdiff training
CharlelieLrt Feb 14, 2025
fede749
Implements patching API. Refactors corrdiff train abnd generate to us…
CharlelieLrt Feb 20, 2025
0ad3c01
Corrdiff function to register new custom dataset
CharlelieLrt Feb 20, 2025
2f906da
Reorganize configs again
CharlelieLrt Feb 22, 2025
3c7f80a
Correction in configs: training duration is NOT in kilo images
CharlelieLrt Feb 24, 2025
d366de0
Readme re-write
CharlelieLrt Feb 25, 2025
b0ad80f
Merge branch 'origin/main'
CharlelieLrt Feb 25, 2025
ae4692f
Updated CHANGELOG
CharlelieLrt Feb 25, 2025
0365019
Fixed formatting
CharlelieLrt Feb 26, 2025
8dff626
Test fixes
CharlelieLrt Feb 26, 2025
a1e5f13
Typo fix
CharlelieLrt Feb 26, 2025
bee4727
Fixes on patching API
CharlelieLrt Feb 28, 2025
aa1d969
Fixed patching bug and tests
CharlelieLrt Feb 28, 2025
e799df9
Simplifications in corrdiff diffusion step
CharlelieLrt Mar 1, 2025
e871773
Forgot to propagate change to test for cordiff diffusion step
CharlelieLrt Mar 1, 2025
9cdabb8
Renamed patching API to explicit 2D
CharlelieLrt Mar 1, 2025
a57c6dc
Fixed shape in test
CharlelieLrt Mar 3, 2025
2e9ce25
Replace loops with fold/unfold patching for perf
CharlelieLrt Mar 4, 2025
2049e7e
Added method to dynamically change number of patches in RandomPatching
CharlelieLrt Mar 4, 2025
638b24e
Adds safety checks for patch shapes in patching function. Fixes tests
CharlelieLrt Mar 5, 2025
f5d3bca
Fixes docs
CharlelieLrt Mar 5, 2025
706f614
Forgot a fix in docs
CharlelieLrt Mar 5, 2025
bc49e05
New embedding selection strategy in CorrDiff UNet models
CharlelieLrt Mar 6, 2025
43cbfff
Updated CHANGELOG.md
CharlelieLrt Mar 6, 2025
deb7bec
Fixed tests for SongUNet position emneddings
CharlelieLrt Mar 7, 2025
1c70ade
More robust tests for patching
CharlelieLrt Mar 7, 2025
84f6a47
Fixed docs bug
CharlelieLrt Apr 1, 2025
e87d3f2
More bugfixes in doc tests
CharlelieLrt Apr 1, 2025
0277ae8
Merge prigin/main into cordiff-usability-enhancements-docs-bugfixes
CharlelieLrt Apr 1, 2025
278eace
Some renaming
CharlelieLrt Apr 1, 2025
626ac9c
Merge branch 'main' into cordiff-usability-enhancements-docs-bugfixes
CharlelieLrt Apr 2, 2025
9c524c5
Merge branch 'main' into cordiff-usability-enhancements-docs-bugfixes
CharlelieLrt Apr 2, 2025
16cf44a
Bugfixes, cleanup, docstrings
CharlelieLrt Apr 5, 2025
deacb6b
Merged origin/cordiff-usability-enhancements-docs-bugfixes
CharlelieLrt Apr 5, 2025
8951a84
Docstring improvement for UNet and EDMPrecondSR
CharlelieLrt Apr 7, 2025
a0f53b9
Docs for InfiniteSampler
CharlelieLrt Apr 7, 2025
015caf4
Corrected Readme info about training/generate from checkpoints
CharlelieLrt Apr 8, 2025
7db9cde
Bugfixes in generate scripts, cleanup debugging flags
CharlelieLrt Apr 8, 2025
7ad5509
Removed blank line from changelog
CharlelieLrt Apr 8, 2025
c45fc43
Fixes in CI tests
CharlelieLrt Apr 8, 2025
4e90ae8
Forgot to commit one of the CI fixes
CharlelieLrt Apr 8, 2025
7b54986
Fix example in doc
CharlelieLrt Apr 8, 2025
d38c935
Merge branch 'main' into cordiff-usability-enhancements-docs-bugfixes
CharlelieLrt Apr 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- DrivAerML dataset support in FIGConvNet example.
- Retraining recipe for DoMINO from a pretrained model checkpoint
- Added Datacenter CFD use case.
- General purpose patching API for patch-based diffusion
- New positional embedding selection strategy for CorrDiff SongUNet models

### Changed

Expand All @@ -25,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated utils in `modulus.launch.logging` to avoid unnecessary `wandb` and `mlflow` imports
- Moved to experiment-based Hydra config in Lagrangian-MGN example
- Make data caching optional in `MeshDatapipe`
- Simplified CorrDiff config files, updated default values
- Refactored CorrDiff losses and samplers to use the patching API
- Support for non-square images and patches in patch-based diffusion

### Deprecated

Expand Down
15 changes: 13 additions & 2 deletions docs/api/modulus.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Filesystem utils
Generative utils
----------------

.. automodule:: modulus.utils.generative.sampler
.. automodule:: modulus.utils.generative.deterministic_sampler
:members:
:show-inheritance:

.. automodule:: modulus.utils.generative.stochastic_sampler
:members:
:show-inheritance:

Expand All @@ -66,4 +70,11 @@ Weather / Climate utils
:show-inheritance:

.. automodule:: modulus.utils.zenith_angle
:show-inheritance:
:show-inheritance:

Patching utils
--------------

.. automodule:: modulus.utils.patching
:members:
:show-inheritance:
Binary file added docs/img/corrdiff_training_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
578 changes: 428 additions & 150 deletions examples/generative/corrdiff/README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

type: hrrr_mini
data_path: /data/corrdiff-mini/hrrr_mini_train.nc
stats_path: /data/corrdiff-mini/stats.json
output_variables: ['10u', '10v']
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: generation
run:
dir: ./outputs/${hydra:job.name}

# Get defaults
defaults:

# Dataset
- dataset/cwb_generate

# Sampler
- sampler/stochastic
#- sampler/deterministic

# Generation
- generation/base
#- generation/patched_based
# Dataset type. Must be overridden.
type: ???
# Path to .nc data file. Must be overridden.
data_path: ???
# Path to json stats file. Must be overriden.
stats_path: ???
# Names of input channels. Must be overridden.
input_variables: ???
# Names of output channels. Must be overridden.
output_variables: ???
# Names of invariants variables. Optional.
invariant_variables: ???
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
Expand All @@ -15,15 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Dataset type. Do not modify.
type: cwb
data_path: /code/2023-01-24-cwb-4years.zarr
# Path to data file. Must be overridden.
data_path: ???
# Indices of input channels
in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
# Indices of output channels
out_channels: [0, 1, 2, 3]
# Shape of the image
img_shape_x: 448
img_shape_y: 448
# Add grid coordinates to the image
add_grid: true
# Factor to downscale the image
ds_factor: 4
# Path to min and max values of the data
min_path: null
max_path: null
# Path to global means of the data
global_means_path: null
# Path to global stds of the data
global_stds_path: null
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Dataset type. Do not modify.
type: gefs_hrrr
data_path: /data
stats_path: /data/stats.json
# Path to .nc data file. Must be overridden.
data_path: ???
# Path to json stats file. Must be overriden.
stats_path: ???
# Names of output channels.
output_variables: ["u10m", "v10m", "t2m", "precip", "cat_snow", "cat_ice", "cat_freez", "cat_rain", "cat_none"]
# Names of probability variables.
prob_variables: ["cat_snow", "cat_ice", "cat_freez", "cat_rain"]
# Names of input surface variables.
input_surface_variables: ["u10m", "v10m", "t2m", "q2m", "sp", "msl", "precipitable_water"]
# Names of input isobaric variables.
input_isobaric_variables: ['u1000', 'u925', 'u850', 'u700', 'u500', 'u250', 'v1000', 'v925', 'v850', 'v700', 'v500', 'v250', 'z1000', 'z925', 'z850', 'z700', 'z500', 'z200', 't1000', 't925', 't850', 't700', 't500', 't100', 'r1000', 'r925', 'r850', 'r700', 'r500', 'r100']
# Factor to downscale the image.
ds_factor: 4
train: False
hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16 [[0,1024], [0,1024]]
# Years to train the model.
train_years: [2020, 2021, 2022, 2023]
# Years to validate the model.
valid_years: [2024]
# Whether to normalize the data.
normalize: True
# Whether to shard the data.
shard: False
overfit: False
# Whether to use all the data.
use_all: False
sample_shape: [-1, -1]
hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: lt_aware_ce_regression
# Name of the preconditioner
hr_mean_conditioning: False
# High-res mean (regression's output) as additional condition

# Dataset type
type: hrrr_mini
# Path to .nc data file. Must be overridden.
data_path: ???
# Path to json stats file. Must be overriden.
stats_path: ???
# Names of output channels. Must be overridden.
output_variables: ['10u', '10v']
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

num_ensembles: 64
# Number of ensembles to generate per input
seed_batch_size: 4
# Size of the batched inference
defaults:
- sampler: stochastic
# Recommended is stochastic sampler. Change to deterministic if needed.

num_ensembles: ???
# Number of ensembles to generate per input. Should be overridden.
seed_batch_size: ???
# Size of the batched inference. Should be overridden.
inference_mode: all
# Choose between "all" (regression + diffusion), "regression" or "diffusion"
patch_size: 448
patch_shape_x: 448
patch_shape_y: 448
# Patch size. Patch-based sampling will be utilized if these dimensions differ from
# img_shape_x and img_shape_y
overlap_pixels: 4
# Number of overlapping pixels between adjacent patches
boundary_pixels: 2
# Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
# artifact.
# Choose between "all" (regression + diffusion), "regression" or "diffusion"
hr_mean_conditioning: true
gridtype: learnable
N_grid_channels: 100
sample_res: full
# Sampling resolution
times_range: null
times:
- 2021-02-02T00:00:00
- 2021-03-02T00:00:00
- 2021-04-02T00:00:00
# hurricane
- 2021-09-12T00:00:00
- 2021-09-12T12:00:00
# Whether to use hr_mean_conditioning
times_range: ???
# Time range to generate. Should be overridden.
has_lead_time: False
# Whether the model has lead time.

perf:
force_fp16: false
Expand All @@ -55,9 +42,3 @@ perf:
num_writer_workers: 1
# number of workers to use for writing file
# To support multiple workers a threadsafe version of the netCDF library must be used

io:
res_ckpt_filename: diffusion_checkpoint.mdlus
# Checkpoint filename for the diffusion model
reg_ckpt_filename: regression_checkpoint.mdlus
# Checkpoint filename for the mean predictor model
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
# limitations under the License.

defaults:
- corrdiff_regression
- base_all

model_args:
model_channels: 64
channel_mult: [1, 2, 2]
attn_resolutions: [16]
patching: False
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: true
name: gefs_hrrr_regression
run:
dir: ./outputs/${hydra:job.name}

# Get defaults
defaults:
- base_all

# Dataset
- dataset/gefs_hrrr

# Model
- model/corrdiff_regression_gefs_hrrr

# Training
- training/corrdiff_regression_gefs_hrrr
patching: True
# Use patch-based sampling
overlap_pix: 4
# Number of overlapping pixels between adjacent patches
boundary_pix: 2
# Number of boundary pixels to be cropped out. 2 is recommended to address the boundary
# artifact.
patch_shape_x: ???
patch_shape_y: ???
# Patch size. Patch-based sampling will be utilized if these dimensions
# differ from img_shape_x and img_shape_y. Needs to be overridden.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# @package _global_.sampler

type: deterministic
num_steps: 9
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# @package _global_.sampler

type: stochastic
boundary_pix: 2
overlap_pix: 4
#overlap_pix has to be no less than 2*boundary_pix
37 changes: 37 additions & 0 deletions examples/generative/corrdiff/conf/base/model/diffusion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: diffusion
# Model type.
hr_mean_conditioning: True
# Recommended to use high-res conditioning for diffusion.
scale_cond_input: False
# If true, also scales the input conditioning. Recommended to False.

# Standard model parameters.
model_args:
gridtype: "sinusoidal"
# Type of positional grid to use: 'sinusoidal', 'learnable', 'linear'.
# Controls how positional information is encoded.
N_grid_channels: 4
# Number of channels for positional grid embeddings
embedding_type: "zero"
# Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
# 'zero' for none
model_type: "SongUNetPosEmbd"
# Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet
# with positional embeddings, 'SongUNetPosEmbd' for UNet with positional
# embeddings, 'DhariwalUNet' for UNet with Fourier embeddings
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


name: lt_aware_ce_regression
# Model type.
hr_mean_conditioning: False
# No high-res conditioning for regression.

# Default model parameters.
model_args:
img_channels: 4
# Number of color channels in the model
N_grid_channels: 4
# Number of channels for positional grid embeddings
embedding_type: "zero"
# Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++,
# 'zero' for none
lead_time_channels: 4
# Number of channels for lead-time embeddings
lead_time_steps: 9
# Number of lead-time steps
model_type: "SongUNetPosLtEmbd"
# Type of model architecture: 'SongUNetPosLtEmbd' for lead-time aware UNet with
# positional embeddings, 'SongUNetPosEmbd' for UNet with positional embeddings,
# 'DhariwalUNet' for UNet with Fourier embeddings
Loading