Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
235 commits
Select commit Hold shift + click to select a range
46d2880
Move filesystems and version_check to core
coreyjadams Nov 3, 2025
c6d04ad
Fix version check tests
coreyjadams Nov 3, 2025
6f36f03
Reorganize distributed, domain_parallel, and begin nn / utils cleanup.
coreyjadams Nov 3, 2025
7824091
Move modules and meta to core. Move registry to core.
coreyjadams Nov 3, 2025
f753573
Add missing init files
coreyjadams Nov 3, 2025
2ef835e
Update build system and specify some deps.
coreyjadams Nov 3, 2025
1603067
Merge branch 'main' into refactor
coreyjadams Nov 3, 2025
1e8df52
Reorganize tests.
coreyjadams Nov 3, 2025
2e1195c
Update init files
coreyjadams Nov 3, 2025
a698685
Clean up neighbor tools.
coreyjadams Nov 3, 2025
258d988
Update testing
coreyjadams Nov 3, 2025
0638b97
Fix compat tests
coreyjadams Nov 3, 2025
b6327cb
Move core model tests to tests/core/
coreyjadams Nov 3, 2025
3ce049a
Add import lint config
coreyjadams Nov 3, 2025
95fa450
Relocate layers
coreyjadams Nov 3, 2025
ba6813d
Move graphcast utils into model directory
coreyjadams Nov 3, 2025
3f10463
Relocating util functionalities.
coreyjadams Nov 4, 2025
339b484
Further clean up and organize tests.
coreyjadams Nov 5, 2025
18df402
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 5, 2025
d6946d9
utils tests are passing now
coreyjadams Nov 5, 2025
66f8d15
Cleaning up distributed tests
coreyjadams Nov 5, 2025
2ee76db
Patching tests working again in nn
coreyjadams Nov 5, 2025
33d525d
Fix sdf test
coreyjadams Nov 5, 2025
a06ad0a
Fix zenith angle tests
coreyjadams Nov 5, 2025
4c845cc
Some organization of tests. Checkpoints is moved into utils.
coreyjadams Nov 5, 2025
3bb64f4
Remove launch.utils and launch.config. Checkpointing is moved to
coreyjadams Nov 5, 2025
4aa332e
Most nn tests are passing
coreyjadams Nov 5, 2025
45686cc
Further cleanup. Getting there!
coreyjadams Nov 5, 2025
bbc54f6
Remove constants file
coreyjadams Nov 5, 2025
8453fea
Add import linting to pre-commit.
coreyjadams Nov 5, 2025
7ff2a2a
Refactor (#1208)
coreyjadams Nov 5, 2025
f850488
Merge branch 'main' into refactor
coreyjadams Nov 5, 2025
1c5f91c
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 5, 2025
21343f5
Unmigrate the insolation utils (#1211)
pzharrington Nov 6, 2025
337c91e
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 7, 2025
4583c42
Move gnn layers and start to fix several model tests.
coreyjadams Nov 7, 2025
e326d4a
AFNO is now passing.
coreyjadams Nov 7, 2025
b95097d
Rnn models passing.
coreyjadams Nov 7, 2025
d8bc6f9
Fix improt
coreyjadams Nov 7, 2025
314f1b2
Healpix tests are working
coreyjadams Nov 7, 2025
9c7d287
Domino and unet working
coreyjadams Nov 7, 2025
0012209
Refactor (#1216)
coreyjadams Nov 7, 2025
32e1dce
Update activations path in dlwp tests (#1217)
pzharrington Nov 7, 2025
afa903f
Updating to address some test issues
coreyjadams Nov 10, 2025
91ceb0a
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 10, 2025
f9130a6
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 10, 2025
ceb1eb8
Merge branch 'main' into refactor
coreyjadams Nov 10, 2025
0592d80
MGN tests passing again
coreyjadams Nov 10, 2025
857b3db
Most graphcast tests passing again
coreyjadams Nov 10, 2025
f89a2fb
Move nd conv layers.
coreyjadams Nov 10, 2025
409200d
update fengwu and pangu
coreyjadams Nov 10, 2025
14b51fd
Update sfno and pix2pix test
coreyjadams Nov 10, 2025
27fd304
update tests for figconvnet, swinrnn, superresnet
coreyjadams Nov 10, 2025
0d22d11
updating more models to pass
coreyjadams Nov 10, 2025
60ba0ce
Update distributed tests, now passing.
coreyjadams Nov 10, 2025
7ec2251
Domain parallel tests now passing.
coreyjadams Nov 11, 2025
d9fe7a4
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 12, 2025
af9e359
Fix active learning imports so tests pass in refactor
coreyjadams Nov 12, 2025
e3b7849
Fix some metric imports
coreyjadams Nov 12, 2025
b1f2ef9
Remove deploy package
coreyjadams Nov 12, 2025
f46ff8c
Remove unused test file
coreyjadams Nov 12, 2025
edd2224
unmigrate these files ... again?
coreyjadams Nov 12, 2025
1c769e3
Update import linter.
coreyjadams Nov 12, 2025
b9aa3dd
Refactor (#1224)
coreyjadams Nov 12, 2025
8d8255a
Merge branch 'main' into refactor
coreyjadams Nov 12, 2025
8b266b0
Cleaning up diffusion models. Not quite done yet.
coreyjadams Nov 12, 2025
8a8a05a
Merge branch 'main' into refactor
coreyjadams Nov 12, 2025
9b0d40d
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 13, 2025
ff0aacf
Restore deleted files
coreyjadams Nov 13, 2025
f11fcd7
Updating more tests.
coreyjadams Nov 13, 2025
9e32712
Further updates to tests. Datapipes almost working.
coreyjadams Nov 14, 2025
4fe41b9
Refactor (#1231)
coreyjadams Nov 14, 2025
0b78d6c
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 17, 2025
ac1fcef
update import paths
coreyjadams Nov 17, 2025
d81ee43
Starting to clean up dependency tree.
coreyjadams Nov 18, 2025
dff27b3
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 18, 2025
8a0a3a5
Refactor (#1233)
coreyjadams Nov 18, 2025
3cb9a02
Added coding standards for model implementations as a custom context …
CharlelieLrt Nov 18, 2025
d7bcd0d
Fixing and adjusting a broad suite of tests.
coreyjadams Nov 19, 2025
d32879d
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 19, 2025
c4ef437
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 19, 2025
b3b7786
Update test/domain_parallel/conftest.py
coreyjadams Nov 19, 2025
af41fdf
Minor fix
coreyjadams Nov 19, 2025
611a029
Refactor (#1234)
coreyjadams Nov 19, 2025
58c909c
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 19, 2025
17ff6de
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 19, 2025
e83ea99
Not seeing any errors in testing ...
coreyjadams Nov 19, 2025
ec163e1
Breakdown of rules into smaller rules (#1236)
CharlelieLrt Nov 19, 2025
15a04f1
Merge branch 'NVIDIA:main' into refactor
coreyjadams Nov 20, 2025
42e4b40
Refactor (#1240)
coreyjadams Nov 20, 2025
ff8ddac
Merge branch 'main' into v2.0-refactor
coreyjadams Nov 20, 2025
51c0ccb
Merge branch 'main' into refactor
coreyjadams Nov 24, 2025
e16f9f2
Refactor (#1247)
coreyjadams Nov 24, 2025
60ccc72
Enable import linting on internal imports.
coreyjadams Nov 24, 2025
9b62b7d
Remove ensure_available function, it's confusing
coreyjadams Nov 24, 2025
f05150f
Add logging imports to utils, and fix imports in examples.
coreyjadams Nov 24, 2025
64d731f
Update imports in minimal examples
coreyjadams Nov 24, 2025
725ecfe
Update structural mechanics examples
coreyjadams Nov 24, 2025
d8e5f05
Update import paths: reservoir_sim
coreyjadams Nov 24, 2025
666be4b
Update import paths: additive manufacturing
coreyjadams Nov 24, 2025
19b8afd
Update import paths: topodiff
coreyjadams Nov 24, 2025
824c76a
Update import paths: weather part 1
coreyjadams Nov 24, 2025
641c110
Update import paths: weather part 2
coreyjadams Nov 24, 2025
2e056db
Update import paths: molecular dynamics
coreyjadams Nov 24, 2025
6a9f6e6
Update import paths: geophysics
coreyjadams Nov 24, 2025
b874e4e
Update import paths: cfd + external_aero 1
coreyjadams Nov 24, 2025
23f2955
Update import paths: cfd + external_aero 2
coreyjadams Nov 24, 2025
581c79a
Remove more DGL examples
coreyjadams Nov 24, 2025
6d780d7
Remove more DGL examples
coreyjadams Nov 24, 2025
7763d96
cfd examples 3
coreyjadams Nov 24, 2025
53fa1cb
Last batch of example import fixes!
coreyjadams Nov 24, 2025
1cd3ada
Merge branch 'v2.0-refactor' into refactor
coreyjadams Nov 24, 2025
5fdcf0f
Enforce and protect external deps in utils.
coreyjadams Nov 25, 2025
b5842e3
Remove DGL. :party:
coreyjadams Nov 25, 2025
da742e7
Don't force models yet
coreyjadams Nov 25, 2025
6c872a0
Refactor (#1249)
coreyjadams Nov 25, 2025
363126a
Automated model registry (#1252)
CharlelieLrt Nov 26, 2025
76a29ef
Metadata name deprecation (#1257)
CharlelieLrt Nov 26, 2025
942c375
Merge main into local refactor
coreyjadams Dec 1, 2025
8d8939d
Refactor (#1258)
coreyjadams Dec 1, 2025
cbc2dd3
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 1, 2025
170efa7
Merge branch 'v2.0-refactor' into refactor
coreyjadams Dec 1, 2025
8898450
Remove IPDB
coreyjadams Dec 1, 2025
8aa8dd9
Few more dep fixes.
coreyjadams Dec 1, 2025
70d9135
Merge branch 'main' into refactor
coreyjadams Dec 2, 2025
ec69852
Refactor (#1261)
coreyjadams Dec 2, 2025
17788b0
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 2, 2025
3c03b08
Add external import coding standards.
coreyjadams Dec 2, 2025
a842398
Update external import standards.
coreyjadams Dec 3, 2025
dae0942
Ensure vtk functions are protected.
coreyjadams Dec 3, 2025
042f7ea
Protect pyvista import
coreyjadams Dec 3, 2025
5bb0e6f
Closing more import gaps
coreyjadams Dec 3, 2025
d35d5c7
Remove DGL from meshgraphkan
coreyjadams Dec 3, 2025
12b98d8
All models now comply with external import linting.
coreyjadams Dec 3, 2025
a879e8d
Remove DGL datapipes
coreyjadams Dec 3, 2025
b200b50
cae datapipes in compliance
coreyjadams Dec 3, 2025
cb1766c
Update pyproject.toml
coreyjadams Dec 3, 2025
d339e1f
Add version numbers to deps
coreyjadams Dec 3, 2025
aad176c
Refactor (#1261)
coreyjadams Dec 3, 2025
6c9cebd
Merge branch 'refactor' into v2.0-refactor
coreyjadams Dec 3, 2025
7422e4c
fix import error from wandb
coreyjadams Dec 3, 2025
75490ea
remove instance check
coreyjadams Dec 3, 2025
ddf6ea9
Initial restructure
CharlelieLrt Dec 5, 2025
8e634f9
Completed restructure of diffusion package
CharlelieLrt Dec 5, 2025
1f66eb6
UV <---> Pip must stay in sync. (#1264)
coreyjadams Dec 8, 2025
ab46322
Fix broken imports
coreyjadams Dec 8, 2025
c8c4da6
Fix README links in transolver and domino examples (#1259)
dran-dev Dec 4, 2025
9132858
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 8, 2025
770589b
Add xarray, timm to core deps
coreyjadams Dec 8, 2025
e07fbd2
update import
coreyjadams Dec 9, 2025
6f92470
Somehow, a number of import protections got broken
coreyjadams Dec 9, 2025
de56395
Automatically select CPU or CPU+CUDA instead of decorating every test.
coreyjadams Dec 10, 2025
281e90c
ensure te installed for serialization test
coreyjadams Dec 10, 2025
289c11d
All CPU tests are passing
coreyjadams Dec 10, 2025
d07eab2
Remove DGL/PyG equivalency tests (#1273)
Alexey-Kamenev Dec 10, 2025
c6d6525
Install ci (#1274)
coreyjadams Dec 12, 2025
ea5ab3a
Remove TensorFlow dependency in Vortex Shedding and Lagrangian MGN ex…
Alexey-Kamenev Dec 12, 2025
18f5872
Change registry behavior and list all models as entry points (#1278)
CharlelieLrt Dec 12, 2025
1c2fa2f
Renamed LearnedSimulator into VGFNLearnedSimulator
CharlelieLrt Dec 13, 2025
c2464d3
Fix tests + improve docs for new register arg in from_torch
CharlelieLrt Dec 13, 2025
6305bb8
Remove physicsnemo.model.Module remaining items
coreyjadams Dec 15, 2025
d1ca859
Remove incorrect meta import
coreyjadams Dec 15, 2025
7753a55
Remove incorrect comment
coreyjadams Dec 15, 2025
8c7c08a
Fix linting errors
coreyjadams Dec 16, 2025
c4b71ea
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 16, 2025
8e55518
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 16, 2025
28d1871
Fixing some linting errors
coreyjadams Dec 16, 2025
35235fe
More linter errors
coreyjadams Dec 16, 2025
7f10726
One more.
coreyjadams Dec 16, 2025
b770b23
Update knn tests
coreyjadams Dec 16, 2025
5667675
Purge pylib cugraphops
coreyjadams Dec 16, 2025
ec6e35b
Remove more cugraphops paths.
coreyjadams Dec 16, 2025
e028a59
Trying to close some CI errors.
coreyjadams Dec 16, 2025
d83168c
Fixing more CI issues
coreyjadams Dec 16, 2025
5bf2fb8
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 16, 2025
b371bba
Fix MGN tests (#1281)
Alexey-Kamenev Dec 16, 2025
b860fcd
Fix apex issues on CPU with a diffusion-specific device fixture.
coreyjadams Dec 16, 2025
bfb9d43
Fixing shard tensor import; adjusting pytorch geometric import point …
coreyjadams Dec 16, 2025
b577342
Fixing more imports.
coreyjadams Dec 16, 2025
85c56e0
fix one or two more
coreyjadams Dec 16, 2025
cd0fb4c
Merge branch 'v2.0-refactor' into restructure-diffusion-subpackage
CharlelieLrt Dec 16, 2025
496b17e
Fix MGK, HMGN tests (#1282)
Alexey-Kamenev Dec 17, 2025
f390748
Fix import error
coreyjadams Dec 17, 2025
6a9f5cb
Remove cugraphops
coreyjadams Dec 17, 2025
76c6bd5
Fix many tests
coreyjadams Dec 17, 2025
6aa5c22
Add migration guide early draft. Update external imports.
coreyjadams Dec 17, 2025
020d928
Attempting to fix the last failing tests.
coreyjadams Dec 17, 2025
23ae40e
Add pre-commit action. (#1286)
coreyjadams Dec 17, 2025
d5fc130
Tweak the CI install and testing of imports / docstrings
coreyjadams Dec 17, 2025
8441672
Wow, the tests were not tied to ANY timezone. It only passes in UTC....
coreyjadams Dec 17, 2025
05f839b
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 17, 2025
3a74e58
fix all but 2 docstring tests
coreyjadams Dec 17, 2025
25f8b56
Merge branch 'main' into v2.0-refactor
coreyjadams Dec 18, 2025
dc03aab
Resolve circular import + fix linting errors.
coreyjadams Dec 17, 2025
bdded65
Fixed broken Group Norm
CharlelieLrt Dec 18, 2025
f431779
Merge branch 'v2.0-refactor' into restructure-diffusion-subpackage
CharlelieLrt Dec 18, 2025
ea2314e
Merge branch 'main' into restructure-diffusion-subpackage
CharlelieLrt Dec 19, 2025
ea3c105
Added diffusion.generate
CharlelieLrt Dec 19, 2025
5bb7a49
Added future feature and deprecation warnings for diffusion module
CharlelieLrt Dec 19, 2025
7388bd3
Defined import-linter contracts for physicsnemo.diffusion
CharlelieLrt Dec 19, 2025
2bea331
Updated PR template with missing item
CharlelieLrt Dec 19, 2025
b3518d0
Added missing diffusion.generate
CharlelieLrt Dec 19, 2025
dae4fbc
Fixed a few remaining paths physicsnemo.models.diffusion that does no…
CharlelieLrt Dec 19, 2025
1969383
CI tests fixes
CharlelieLrt Dec 19, 2025
b452e5b
mmiranda nvidia style guide Updates diffusion.rst
megnvidia Dec 19, 2025
55fee4b
mmiranda smol style guide Updates physicsnemo.utils.rst
megnvidia Dec 19, 2025
cdd92cf
Fixed checklist in PR template
CharlelieLrt Dec 19, 2025
071a7e3
Deleted comment in .importlinter
CharlelieLrt Dec 19, 2025
707e0d2
Fixed references in diffusion.rst
CharlelieLrt Dec 19, 2025
0bd74ef
Merge branch 'restructure-diffusion-subpackage' of https://github.com…
CharlelieLrt Dec 19, 2025
4aa1f7b
Merge branch 'main' into restructure-diffusion-subpackage
CharlelieLrt Jan 5, 2026
fca12d1
Fix checkpoint loading with Module subclass when known
CharlelieLrt Jan 6, 2026
8f38564
Deleted physicsnemo/compat
CharlelieLrt Jan 6, 2026
08e89f2
Deleted useless comments in flow_reconstruction_diffusion example
CharlelieLrt Jan 6, 2026
df3ad90
Renamed Attantion into UNetAttention
CharlelieLrt Jan 6, 2026
205aa97
Merge branch 'main' into restructure-diffusion-subpackage
CharlelieLrt Jan 6, 2026
6ff908e
Implemented BasePreconditioner
CharlelieLrt Jan 7, 2026
f08a686
Improvements to BaseConditioner docs
CharlelieLrt Jan 7, 2026
dc733a9
Implemented new preconditioners based on BasePerconditioner
CharlelieLrt Jan 7, 2026
fa02d48
Migrated legacy preconditioners to reuse new preconditioners
CharlelieLrt Jan 7, 2026
f7b8494
Initial implementation of tests for preconditioners
CharlelieLrt Jan 8, 2026
047ca44
Added reference data for non-regression CI tests of preconditioners
CharlelieLrt Jan 8, 2026
2866c10
Improvements to preconditioners CI tests
CharlelieLrt Jan 8, 2026
5f9a309
Adedd a few details in BasePreconditioner doctrsing
CharlelieLrt Jan 8, 2026
5af0aff
Merge branch 'main' into diffusion-preconditioners-refactor
CharlelieLrt Jan 8, 2026
17ee57e
Updated CHANGELOG.md
CharlelieLrt Jan 8, 2026
1e2d2a7
Improved documentation of signature requirement in BasePreconditioner
CharlelieLrt Jan 8, 2026
ffcb026
Renamed BasePreconditioner into BaseAffinePreconditioner
CharlelieLrt Jan 9, 2026
5d5c66a
Added DiffusionModel protocol to specify diffusion models signature
CharlelieLrt Jan 9, 2026
993db63
Changed condition argument to TensorDict instead of Dict of tensors
CharlelieLrt Jan 9, 2026
22dac49
Moved all preconditioners scalar attributes to pytorch buffers instea…
CharlelieLrt Jan 9, 2026
9c5f53b
Improvements to make precondtioners tests more robust on GPU
CharlelieLrt Jan 10, 2026
f561aa6
Removed deterministic setting for tests
CharlelieLrt Jan 10, 2026
b53de49
Added examples in docstrings of all precondtioners
CharlelieLrt Jan 10, 2026
edf00f3
Fix bug in docstring examples
CharlelieLrt Jan 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Refactored diffusion preconditioners in
`physicsnemo.diffusion.preconditioners` relying on a new abstract base class
`BaseAffinePreconditioner` for preconditioning schemes using affine
transformations. Existing preconditioners (`VPPrecond`, `VEPrecond`,
`iDDPMPrecond`, `EDMPrecond`) reimplemented based on this new interface.

### Changed

- PhysicsNemo v2.0 contains significant reorganization of tools. Please see
Expand Down
2 changes: 2 additions & 0 deletions physicsnemo/diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import DiffusionModel # noqa: F401
88 changes: 88 additions & 0 deletions physicsnemo/diffusion/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protocols and type hints for diffusion model interfaces."""

from typing import Any, Protocol, runtime_checkable

import torch
from jaxtyping import Float
from tensordict import TensorDict


@runtime_checkable
class DiffusionModel(Protocol):
r"""
Protocol defining the common interface for diffusion models.

A diffusion model is any neural network or function that transforms a noisy
state ``x`` at diffusion time (or noise level) ``t`` into a prediction.
This protocol defines the standard interface that all diffusion models must
satisfy.

Any model or function that implements this interface can be used with
preconditioners, losses, samplers, and other diffusion utilities.

The interface is **prediction-agnostic**: whether your model predicts
clean data (:math:`\mathbf{x}_0`), noise (:math:`\epsilon`), score
(:math:`\nabla \log p`), or velocity (:math:`\mathbf{v}`), the signature
remains the same.

Examples
--------
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion import DiffusionModel
>>>
>>> class Denoiser:
... def __call__(self, x, t, condition, **kwargs):
... return F.relu(x)
...
>>> isinstance(Denoiser(), DiffusionModel)
True
"""

def __call__(
self,
x: Float[torch.Tensor, "B *dims"], # noqa: F821
t: Float[torch.Tensor, "B"], # noqa: F821
condition: TensorDict,
**model_kwargs: Any,
) -> Float[torch.Tensor, "B *dims"]: # noqa: F821
r"""
Forward pass of the diffusion model.

Parameters
----------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional
dimensions (e.g., channels and spatial dimensions).
t : torch.Tensor
Diffusion time or noise level tensor of shape :math:`(B,)`.
condition : TensorDict
TensorDict containing conditioning tensors. The TensorDict should
have batch size :math:`B` matching that of ``x``. If the model is
unconditional, the condition should be the empty ``TensorDict()``.
**model_kwargs : Any
Additional keyword arguments specific to the model implementation.

Returns
-------
torch.Tensor
Model output with the same shape as ``x``.
"""
...
7 changes: 7 additions & 0 deletions physicsnemo/diffusion/preconditioners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@
VPPrecond,
iDDPMPrecond,
)
from .preconditioners import ( # noqa: F401
BaseAffinePreconditioner,
EDMPreconditioner,
IDDPMPreconditioner,
VEPreconditioner,
VPPreconditioner,
)
Loading