diff --git a/CHANGELOG.md b/CHANGELOG.md index f03cb23f45..bfbda9bef6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Refactored diffusion preconditioners in + `physicsnemo.diffusion.preconditioners` relying on a new abstract base class + `BaseAffinePreconditioner` for preconditioning schemes using affine + transformations. Existing preconditioners (`VPPrecond`, `VEPrecond`, + `iDDPMPrecond`, `EDMPrecond`) reimplemented based on this new interface. + ### Changed - PhysicsNemo v2.0 contains significant reorganization of tools. Please see diff --git a/physicsnemo/diffusion/__init__.py b/physicsnemo/diffusion/__init__.py index b2340c62ce..f8f8b83348 100644 --- a/physicsnemo/diffusion/__init__.py +++ b/physicsnemo/diffusion/__init__.py @@ -13,3 +13,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .base import DiffusionModel # noqa: F401 diff --git a/physicsnemo/diffusion/base.py b/physicsnemo/diffusion/base.py new file mode 100644 index 0000000000..2f80e955cc --- /dev/null +++ b/physicsnemo/diffusion/base.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protocols and type hints for diffusion model interfaces.""" + +from typing import Any, Protocol, runtime_checkable + +import torch +from jaxtyping import Float +from tensordict import TensorDict + + +@runtime_checkable +class DiffusionModel(Protocol): + r""" + Protocol defining the common interface for diffusion models. + + A diffusion model is any neural network or function that transforms a noisy + state ``x`` at diffusion time (or noise level) ``t`` into a prediction. + This protocol defines the standard interface that all diffusion models must + satisfy. + + Any model or function that implements this interface can be used with + preconditioners, losses, samplers, and other diffusion utilities. + + The interface is **prediction-agnostic**: whether your model predicts + clean data (:math:`\mathbf{x}_0`), noise (:math:`\epsilon`), score + (:math:`\nabla \log p`), or velocity (:math:`\mathbf{v}`), the signature + remains the same. + + Examples + -------- + >>> import torch + >>> import torch.nn.functional as F + >>> from physicsnemo.diffusion import DiffusionModel + >>> + >>> class Denoiser: + ... def __call__(self, x, t, condition, **kwargs): + ... return F.relu(x) + ... + >>> isinstance(Denoiser(), DiffusionModel) + True + """ + + def __call__( + self, + x: Float[torch.Tensor, "B *dims"], # noqa: F821 + t: Float[torch.Tensor, "B"], # noqa: F821 + condition: TensorDict, + **model_kwargs: Any, + ) -> Float[torch.Tensor, "B *dims"]: # noqa: F821 + r""" + Forward pass of the diffusion model. 
+ + Parameters + ---------- + x : torch.Tensor + Noisy latent state of shape :math:`(B, *)` where :math:`B` is the + batch size and :math:`*` denotes any number of additional + dimensions (e.g., channels and spatial dimensions). + t : torch.Tensor + Diffusion time or noise level tensor of shape :math:`(B,)`. + condition : TensorDict + TensorDict containing conditioning tensors. The TensorDict should + have batch size :math:`B` matching that of ``x``. If the model is + unconditional, the condition should be the empty ``TensorDict()``. + **model_kwargs : Any + Additional keyword arguments specific to the model implementation. + + Returns + ------- + torch.Tensor + Model output with the same shape as ``x``. + """ + ... diff --git a/physicsnemo/diffusion/preconditioners/__init__.py b/physicsnemo/diffusion/preconditioners/__init__.py index 3206f0343c..9640db76b1 100644 --- a/physicsnemo/diffusion/preconditioners/__init__.py +++ b/physicsnemo/diffusion/preconditioners/__init__.py @@ -24,3 +24,10 @@ VPPrecond, iDDPMPrecond, ) +from .preconditioners import ( # noqa: F401 + BaseAffinePreconditioner, + EDMPreconditioner, + IDDPMPreconditioner, + VEPreconditioner, + VPPreconditioner, +) diff --git a/physicsnemo/diffusion/preconditioners/legacy.py b/physicsnemo/diffusion/preconditioners/legacy.py index 43df64a10c..9615a76bc3 100644 --- a/physicsnemo/diffusion/preconditioners/legacy.py +++ b/physicsnemo/diffusion/preconditioners/legacy.py @@ -32,6 +32,12 @@ from physicsnemo.core.warnings import LegacyFeatureWarning from ._utils import _wrapped_property +from .preconditioners import ( + EDMPreconditioner, + IDDPMPreconditioner, + VEPreconditioner, + VPPreconditioner, +) warnings.warn( "The preconditioner classes 'VPPrecond', 'VEPrecond', 'iDDPMPrecond', " @@ -65,7 +71,7 @@ class VPPrecondMetaData(ModelMetaData): auto_grad: bool = False -class VPPrecond(Module): +class VPPrecond(VPPreconditioner): """ Preconditioning corresponding to the variance preserving (VP) formulation. 
@@ -112,25 +118,34 @@ def __init__( model_type: str = "SongUNet", **model_kwargs: dict, ): - super().__init__(meta=VPPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.beta_d = beta_d - self.beta_min = beta_min - self.M = M - self.epsilon_t = epsilon_t - self.sigma_min = float(self.sigma(epsilon_t)) - self.sigma_max = float(self.sigma(1)) + # Create the underlying model model_class = getattr(network_module, model_type) - self.model = model_class( + model = model_class( img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs, - ) # TODO needs better handling + ) + + # Initialize parent class with model and VP parameters + super().__init__( + model=model, + beta_d=beta_d, + beta_min=beta_min, + M=M, + ) + # Override meta from parent + self.meta = VPPrecondMetaData + + # Store legacy-specific attributes + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(torch.tensor(epsilon_t))) + self.sigma_max = float(self.sigma(torch.tensor(1.0))) def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): x = x.to(torch.float32) @@ -148,10 +163,8 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) else torch.float32 ) - c_skip = 1 - c_out = -sigma - c_in = 1 / (sigma**2 + 1).sqrt() - c_noise = (self.M - 1) * self.sigma_inv(sigma) + # Use parent's compute_coefficients method + c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma) F_x = self.model( (c_in * x).to(dtype), @@ -167,7 +180,7 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) D_x = c_skip * x + c_out * F_x.to(torch.float32) return D_x - def sigma(self, t: Union[float, torch.Tensor]): + def sigma(self, t: Union[float, torch.Tensor]) -> torch.Tensor: """ Compute the sigma(t) value for a given t based on the VP formulation. @@ -185,7 +198,7 @@ def sigma(self, t: Union[float, torch.Tensor]): The computed sigma(t) value(s). """ t = torch.as_tensor(t) - return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() + return super().sigma(t) def sigma_inv(self, sigma: Union[float, torch.Tensor]): """ @@ -247,7 +260,7 @@ class VEPrecondMetaData(ModelMetaData): auto_grad: bool = False -class VEPrecond(Module): +class VEPrecond(VEPreconditioner): """ Preconditioning corresponding to the variance exploding (VE) formulation. 
@@ -288,21 +301,28 @@ def __init__( model_type: str = "SongUNet", **model_kwargs: dict, ): - super().__init__(meta=VEPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.sigma_min = sigma_min - self.sigma_max = sigma_max + # Create the underlying model model_class = getattr(network_module, model_type) - self.model = model_class( + model = model_class( img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs, - ) # TODO needs better handling + ) + + # Initialize parent class with model + super().__init__(model=model) + # Override meta from parent + self.meta = VEPrecondMetaData + + # Store legacy-specific attributes + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): x = x.to(torch.float32) @@ -320,10 +340,8 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) else torch.float32 ) - c_skip = 1 - c_out = sigma - c_in = 1 - c_noise = (0.5 * sigma).log() + # Use parent's compute_coefficients method + c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma) F_x = self.model( (c_in * x).to(dtype), @@ -375,7 +393,7 @@ class iDDPMPrecondMetaData(ModelMetaData): auto_grad: bool = False -class iDDPMPrecond(Module): +class iDDPMPrecond(IDDPMPreconditioner): """ Preconditioning corresponding to the improved DDPM (iDDPM) formulation. @@ -419,33 +437,34 @@ def __init__( model_type="DhariwalUNet", **model_kwargs, ): - super().__init__(meta=iDDPMPrecondMetaData) - self.img_resolution = img_resolution - self.img_channels = img_channels - self.label_dim = label_dim - self.use_fp16 = use_fp16 - self.C_1 = C_1 - self.C_2 = C_2 - self.M = M + # Create the underlying model model_class = getattr(network_module, model_type) - self.model = model_class( + model = model_class( img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels * 2, label_dim=label_dim, **model_kwargs, - ) # TODO needs better handling + ) + + # Initialize parent class with model and iDDPM parameters + super().__init__( + model=model, + C_1=C_1, + C_2=C_2, + M=M, + ) + # Override meta from parent + self.meta = iDDPMPrecondMetaData - u = torch.zeros(M + 1) - for j in range(M, 0, -1): # M, ..., 1 - u[j - 1] = ( - (u[j] ** 2 + 1) - / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - - 1 - ).sqrt() - self.register_buffer("u", u) - self.sigma_min = float(u[M - 1]) - self.sigma_max = float(u[0]) + # Store legacy-specific attributes + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + # Use the u buffer from parent to compute sigma_min and sigma_max + self.sigma_min = float(self.u[M - 1]) + self.sigma_max = float(self.u[0]) def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): x = x.to(torch.float32) @@ -463,12 +482,8 @@ def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs) else torch.float32 ) - c_skip = 1 - c_out = -sigma - c_in = 1 / (sigma**2 + 1).sqrt() - c_noise = ( - self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) - ) + # Compute coefficients using parent's method + c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma) F_x = 
self.model(
(c_in * x).to(dtype),
@@ -548,7 +563,7 @@ class EDMPrecondMetaData(ModelMetaData):
auto_grad: bool = False
-class EDMPrecond(Module):
+class EDMPrecond(EDMPreconditioner):
"""
Improved preconditioning proposed in the paper "Elucidating the Design
Space of Diffusion-Based Generative Models" (EDM)
@@ -605,33 +620,38 @@ def __init__(
img_out_channels=None,
**model_kwargs,
):
- super().__init__(meta=EDMPrecondMetaData)
- self.img_resolution = img_resolution
- if img_in_channels is not None:
- img_in_channels = img_in_channels
- else:
+ # Resolve input/output channels
+ if img_in_channels is None:
img_in_channels = img_channels
- if img_out_channels is not None:
- img_out_channels = img_out_channels
- else:
+ if img_out_channels is None:
img_out_channels = img_channels
- self.label_dim = label_dim
- self.use_fp16 = use_fp16
- self.sigma_min = sigma_min
- self.sigma_max = sigma_max
- self.sigma_data = sigma_data
-
+ # Create the underlying model
model_class = getattr(network_module, model_type)
- self.model = model_class(
+ model = model_class(
img_resolution=img_resolution,
in_channels=img_in_channels,
out_channels=img_out_channels,
label_dim=label_dim,
**model_kwargs,
- ) # TODO needs better handling
+ )
- def forward(
+ # Initialize parent class with model and sigma_data
+ super().__init__(
+ model=model,
+ sigma_data=sigma_data,
+ )
+ # Override meta from parent
+ self.meta = EDMPrecondMetaData
+
+ # Store legacy-specific attributes
+ self.img_resolution = img_resolution
+ self.label_dim = label_dim
+ self.use_fp16 = use_fp16
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+
+ def forward( # type: ignore[override]
self,
x,
sigma,
@@ -655,10 +675,8 @@ def forward(
else torch.float32
)
- c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
- c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
- c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
- c_noise = sigma.log() / 4
+ # Use parent's compute_coefficients method
+ c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma)
arg = c_in * x
diff --git a/physicsnemo/diffusion/preconditioners/preconditioners.py b/physicsnemo/diffusion/preconditioners/preconditioners.py
index cfaad66ca9..b574e15f44 100644
--- a/physicsnemo/diffusion/preconditioners/preconditioners.py
+++ b/physicsnemo/diffusion/preconditioners/preconditioners.py
@@ -14,14 +14,800 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import warnings
+import math
+from abc import ABC, abstractmethod
+from typing import Any, Tuple
-from physicsnemo.core.warnings import FutureFeatureWarning
+import torch
+from tensordict import TensorDict
-warnings.warn(
- "The 'physicsnemo.diffusion.preconditioners.preconditioners' module is a "
- "placeholder for future functionality that will be implemented in an "
- "upcoming release.",
- FutureFeatureWarning,
- stacklevel=2,
-)
+from physicsnemo.core.meta import ModelMetaData
+from physicsnemo.core.module import Module
+
+# TODO: once noise schedulers are implemented, some of the methods they define
+# can be reused here, e.g. preconditioner.sigma = noise_scheduler.sigma for the
+# noise schedule. This would avoid duplicating code between the
+# preconditioners and the noise schedulers. Particularly for the iDDPM
+# preconditioner, which requires more computations than the other
+# preconditioners.
+ + +class BaseAffinePreconditioner(Module, ABC): + r""" + Abstract base class for diffusion model preconditioners using an affine + transformation. + + This class provides a standardized interface for implementing + preconditioners that use affine transformations of the model + input and output. + + The preconditioner wraps a neural network model :math:`F` and applies + a preconditioning formula to transform the network output to produce + the preconditioned output :math:`D(\mathbf{x}, t)` according to: + + .. math:: + + D(\mathbf{x}, t) = c_{\text{skip}}(t) \mathbf{x} + + c_{\text{out}}(t) F(c_{\text{in}}(t) \mathbf{x}, c_{\text{noise}}(t)) + + where: + + - :math:`c_{\text{in}}(t)`: Input scaling coefficient + - :math:`c_{\text{noise}}(t)`: Noise conditioning value + - :math:`c_{\text{out}}(t)`: Output scaling coefficient + - :math:`c_{\text{skip}}(t)`: Skip connection scaling coefficient + + and where :math:`\mathbf{x}` is the latent state and :math:`t` is the + diffusion time. + + The wrapped model :math:`F` must be an instance of + :class:`~physicsnemo.core.Module` that satisfies the + :class:`~physicsnemo.diffusion.DiffusionModel` interface, with the + following signature: + + .. code-block:: python + + model( + x: torch.Tensor, # Shape: (B, *) + t: torch.Tensor, # Shape: (B,) + condition: TensorDict, + **model_kwargs: Any, + ) -> torch.Tensor # Shape: (B, *) + + The preconditioner is agnostic to the prediction target of the wrapped + model :math:`F`. The same preconditioning formula is applied regardless of + whether the model is an :math:`\mathbf{x}_0`-predictor, an + :math:`\epsilon`-predictor, a score predictor, or a + :math:`\mathbf{v}`-predictor. + + .. note:: + + The preconditioner itself also satisfies the + :class:`~physicsnemo.diffusion.DiffusionModel` interface, meaning it + does not change the signature of the wrapped model :math:`F`, and it + can be used anywhere a diffusion model is expected. + + Parameters + ---------- + model : physicsnemo.Module + The underlying neural network model :math:`F` to wrap with the + signature described above. + meta : ModelMetaData, optional + Meta data class for storing info regarding model, by default None. + Subclasses can pass their own metadata. + + Forward + ------- + x : torch.Tensor + Noisy latent state of shape :math:`(B, *)` where :math:`B` is the + batch size and :math:`*` denotes any number of additional dimensions. + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)`. + condition : TensorDict + TensorDict containing conditioning tensors with batch size :math:`B` + matching that of ``x``. Passed to the wrapped ``model`` unchanged. + **model_kwargs : Any + Additional keyword arguments passed to the underlying model. + + Outputs + ------- + torch.Tensor + Preconditioned model output with the same shape as the original model + output. + + .. note:: + + To implement a new preconditioner, a subclass of + :class:`BaseAffinePreconditioner` must be defined, and some methods + have to be implemented: + + - Subclasses must implement the :meth:`compute_coefficients` method to + define the specific preconditioning scheme. + + - A :meth:`sigma` method can optionally be implemented. + If a subclass implements the :meth:`sigma` method, the diffusion time + :math:`t` is first transformed to a noise level :math:`\sigma(t)` + before being passed to :meth:`compute_coefficients`. 
This allows
+ implementing preconditioners for different noise schedules while
+ keeping the same preconditioning interface, in particular for
+ preconditioning schemes based on noise level (that is
+ :math:`c_{\text{in}}(\sigma)`,
+ :math:`c_{\text{noise}}(\sigma)`, :math:`c_{\text{out}}(\sigma)`,
+ :math:`c_{\text{skip}}(\sigma)` instead of :math:`c_{\text{in}}(t)`,
+ :math:`c_{\text{noise}}(t)`, :math:`c_{\text{out}}(t)`,
+ :math:`c_{\text{skip}}(t)`).
+
+ - The ``forward`` method of the preconditioner *should not* be
+ overridden.
+
+ .. note::
+
+ The argument ``t`` of the preconditioner ``forward`` method is always
+ assumed to be the diffusion time. For preconditioning schemes based
+ on noise level, the noise level :math:`\sigma(t)` is computed internally
+ using the :meth:`sigma` method.
+
+ Examples
+ --------
+ The following example shows how to implement a classical EDM
+ preconditioner. For EDM, there is no need to implement the :meth:`sigma`
+ method since :math:`\sigma(t) = t` (noise level and diffusion time are the
+ same).
+
+ We first define a simple model to wrap:
+
+ >>> import torch
+ >>> from physicsnemo.nn import Module
+ >>> class SimpleModel(Module):
+ ... def __init__(self, channels: int):
+ ... super().__init__()
+ ... self.channels = channels
+ ... self.net = torch.nn.Conv2d(channels, channels, 1)
+ ...
+ ... def forward(self, x, t, condition):
+ ... return self.net(x)
+
+ Now we define the EDM preconditioner:
+
+ >>> from physicsnemo.diffusion.preconditioners import (
+ ... BaseAffinePreconditioner,
+ ... )
+ >>> class EDMPreconditioner(BaseAffinePreconditioner):
+ ... def __init__(self, model, sigma_data: float = 0.5):
+ ... super().__init__(model)
+ ... self.sigma_data = sigma_data
+ ...
+ ... def compute_coefficients(self, t: torch.Tensor):
+ ... # For EDM sigma(t) = t, so the argument passed to
+ ... # compute_coefficients is already sigma(t)
+ ... sigma_data = self.sigma_data
+ ... c_skip = sigma_data**2 / (t**2 + sigma_data**2)
+ ... c_out = t * sigma_data / (t**2 + sigma_data**2).sqrt()
+ ... c_in = 1 / (sigma_data**2 + t**2).sqrt()
+ ... c_noise = t.log() / 4
+ ... return c_in, c_noise, c_out, c_skip
+ ...
+ >>> from tensordict import TensorDict
+ >>> model = SimpleModel(channels=3)
+ >>> precond = EDMPreconditioner(model, sigma_data=0.5)
+ >>> x = torch.randn(2, 3, 16, 16)
+ >>> t = torch.rand(2)
+ >>> out = precond(x, t, TensorDict())
+ >>> out.shape
+ torch.Size([2, 3, 16, 16])
+
+ The following example shows how to override the :meth:`sigma` method to
+ implement a Variance Exploding (VE) preconditioner where
+ :math:`\sigma(t) = \sqrt{t}`.
+
+ >>> class VEPreconditioner(BaseAffinePreconditioner):
+ ... def __init__(self, model):
+ ... super().__init__(model)
+ ...
+ ... def sigma(self, t: torch.Tensor) -> torch.Tensor:
+ ... # Override sigma to implement VE noise schedule
+ ... return t.sqrt()
+ ...
+ ... def compute_coefficients(self, sigma: torch.Tensor):
+ ... # Here the argument passed to compute_coefficients is
+ ... # sigma(t) = sqrt(t) due to override of the sigma method
+ ... c_skip = torch.ones_like(sigma)
+ ... c_out = sigma
+ ... c_in = torch.ones_like(sigma)
+ ... c_noise = (0.5 * sigma).log()
+ ... return c_in, c_noise, c_out, c_skip
+ ...
+ >>> precond_ve = VEPreconditioner(model) + >>> out_ve = precond_ve(x, t, TensorDict()) + >>> out_ve.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + model: Module, + meta: ModelMetaData | None = None, + ) -> None: + super().__init__() + self.meta = meta + self.model = model + + @abstractmethod + def compute_coefficients( + self, t: torch.Tensor, / + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute the preconditioning coefficients for a given diffusion time + :math:`t` or noise level :math:`\sigma`. + + This abstract method must be implemented by subclasses to define + the specific preconditioning scheme. + + Parameters + ---------- + t : torch.Tensor + Diffusion time (or noise level if :meth:`sigma` is + implemented) tensor of shape :math:`(B, 1, ..., 1)` where + :math:`B` is the batch size and the trailing singleton + dimensions match the spatial dimensions of the latent state + ``x`` for broadcasting. + + Returns + ------- + c_in : torch.Tensor + Input scaling coefficient of shape :math:`(B, 1, ..., 1)`. + c_noise : torch.Tensor + Noise conditioning value of shape :math:`(B, 1, ..., 1)`. + c_out : torch.Tensor + Output scaling coefficient of shape :math:`(B, 1, ..., 1)`. + c_skip : torch.Tensor + Skip connection scaling coefficient of shape + :math:`(B, 1, ..., 1)`. + """ + ... + + def sigma(self, t: torch.Tensor) -> torch.Tensor: + r""" + Map diffusion time :math:`t` to noise level :math:`\sigma(t)`. + + By default, this is the identity function :math:`\sigma(t) = t`. + Subclasses can override this to implement preconditioners for different + noise schedules. + + When overridden, the output of this method is passed to + :meth:`compute_coefficients` instead of the raw time ``t``. + + Parameters + ---------- + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)` where + :math:`B` is the batch size. + + Returns + ------- + torch.Tensor + Noise level :math:`\sigma(t)` of shape :math:`(B,)`. + """ + return t + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + condition: TensorDict, + **model_kwargs: Any, + ) -> torch.Tensor: + if not torch.compiler.is_compiling(): + B = x.shape[0] + if t.shape != (B,): + raise ValueError( + f"Expected t to have shape ({B},) matching batch size of " + f"x, but got {t.shape}." + ) + if condition.batch_size and condition.batch_size[0] != B: + cond_B = condition.batch_size[0] + raise ValueError( + f"Condition TensorDict has batch size {cond_B} " + f"but expected {B} to match x." + ) + + # Map time step to noise level via sigma method + sigma_t = self.sigma(t).reshape(-1, *([1] * (x.ndim - 1))) + + # Compute preconditioning coefficients + c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma_t) + + # Forward through the underlying model + F_x = self.model( + c_in * x, + c_noise.flatten(), + condition, + **model_kwargs, + ) + + D_x = c_skip * x + c_out * F_x + + return D_x + + +class VPPreconditioner(BaseAffinePreconditioner): + r""" + Variance Preserving (VP) preconditioner. + + Implements the preconditioning scheme from the VP formulation of + score-based generative models. + + The noise schedule is: + + .. math:: + + \sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2 + + \beta_{\min} t\right) - 1} + + The preconditioning coefficients are: + + .. 
math:: + + c_{\text{skip}} &= 1 \\ + c_{\text{out}} &= -\sigma \\ + c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\ + c_{\text{noise}} &= (M - 1) \cdot \sigma^{-1}(\sigma) + + Parameters + ---------- + model : physicsnemo.Module + The underlying neural network model to wrap with signature described in + :class:`BaseAffinePreconditioner`. + beta_d : float, optional + Extent of the noise level schedule, by default 19.9. + beta_min : float, optional + Initial slope of the noise level schedule, by default 0.1. + M : int, optional + Number of discretization steps in the DDPM formulation, + by default 1000. + + Forward + ------- + x : torch.Tensor + Noisy latent state of shape :math:`(B, *)` where :math:`B` is the + batch size and :math:`*` denotes any number of additional dimensions. + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)`. + condition : TensorDict + TensorDict containing conditioning tensors with batch size :math:`B` + matching that of ``x``. Passed to the wrapped ``model`` unchanged. + **model_kwargs : Any + Additional keyword arguments passed to the underlying model. + + Outputs + ------- + torch.Tensor + Preconditioned model output with the same shape as the original model + output. + + Note + ---- + Reference: `Score-Based Generative Modeling through Stochastic + Differential Equations `_ + + Examples + -------- + >>> import torch + >>> from tensordict import TensorDict + >>> from physicsnemo.core import Module + >>> # Define a simple model satisfying the diffusion model interface + >>> class SimpleModel(Module): + ... def __init__(self, channels: int): + ... super().__init__() + ... self.net = torch.nn.Conv2d(channels, channels, 1) + ... def forward(self, x, t, condition): + ... return self.net(x) + >>> model = SimpleModel(channels=3) + >>> precond = VPPreconditioner(model, beta_d=19.9, beta_min=0.1, M=1000) + >>> x = torch.randn(2, 3, 16, 16) # batch of 2 images + >>> t = torch.rand(2) # diffusion time for each sample + >>> condition = TensorDict({}, batch_size=[2]) + >>> out = precond(x, t, condition) + >>> out.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + model: Module, + beta_d: float = 19.9, + beta_min: float = 0.1, + M: int = 1000, + ) -> None: + super().__init__(model) + self.register_buffer("beta_d", torch.tensor(beta_d)) + self.register_buffer("beta_min", torch.tensor(beta_min)) + self.register_buffer("M", torch.tensor(M)) + + def sigma(self, t: torch.Tensor) -> torch.Tensor: + r""" + Compute :math:`\sigma(t)` for the VP formulation. + + Parameters + ---------- + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)`. + + Returns + ------- + torch.Tensor + Noise level :math:`\sigma(t)` of shape :math:`(B,)`. + """ + exponent = 0.5 * self.beta_d * (t**2) + self.beta_min * t + return (exponent.exp() - 1).sqrt() + + def compute_coefficients( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute VP preconditioning coefficients. + + Parameters + ---------- + sigma : torch.Tensor + Noise level tensor of shape :math:`(B, 1, ..., 1)`. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Preconditioning coefficients (:math:`c_{\text{in}}`, + :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`, + :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`. 
+ """ + c_skip = torch.ones_like(sigma) + c_out = -sigma + c_in = 1 / (sigma**2 + 1).sqrt() + # Compute t = sigma_inv(sigma) + t = ( + (self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt() + - self.beta_min + ) / self.beta_d + c_noise = (self.M - 1) * t + return c_in, c_noise, c_out, c_skip + + +class VEPreconditioner(BaseAffinePreconditioner): + r""" + Variance Exploding (VE) preconditioner. + + Implements the preconditioning scheme from the VE formulation of + score-based generative models. + + For VE, the noise schedule is identity: :math:`\sigma(t) = t`. + + The preconditioning coefficients are: + + .. math:: + + c_{\text{skip}} &= 1 \\ + c_{\text{out}} &= \sigma \\ + c_{\text{in}} &= 1 \\ + c_{\text{noise}} &= \log(0.5 \cdot \sigma) + + Parameters + ---------- + model : physicsnemo.Module + The underlying neural network model to wrap with signature described in + :class:`BaseAffinePreconditioner`. + + Forward + ------- + x : torch.Tensor + Noisy latent state of shape :math:`(B, *)` where :math:`B` is the + batch size and :math:`*` denotes any number of additional dimensions. + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)`. + condition : TensorDict + TensorDict containing conditioning tensors with batch size :math:`B` + matching that of ``x``. Passed to the wrapped ``model`` unchanged. + **model_kwargs : Any + Additional keyword arguments passed to the underlying model. + + Outputs + ------- + torch.Tensor + Preconditioned model output with the same shape as the original model + output. + + Note + ---- + Reference: `Score-Based Generative Modeling through Stochastic + Differential Equations `_ + + Examples + -------- + >>> import torch + >>> from tensordict import TensorDict + >>> from physicsnemo.core import Module + >>> # Define a simple model satisfying the diffusion model interface + >>> class SimpleModel(Module): + ... def __init__(self, channels: int): + ... super().__init__() + ... self.net = torch.nn.Conv2d(channels, channels, 1) + ... def forward(self, x, t, condition): + ... return self.net(x) + >>> model = SimpleModel(channels=3) + >>> precond = VEPreconditioner(model) + >>> x = torch.randn(2, 3, 16, 16) # batch of 2 images + >>> t = torch.rand(2) # diffusion time for each sample + >>> condition = TensorDict({}, batch_size=[2]) + >>> out = precond(x, t, condition) + >>> out.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__(self, model: Module) -> None: + super().__init__(model) + + def compute_coefficients( + self, t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute VE preconditioning coefficients. + + Parameters + ---------- + t : torch.Tensor + Diffusion time tensor of shape :math:`(B, 1, ..., 1)`. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Preconditioning coefficients (:math:`c_{\text{in}}`, + :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`, + :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`. + """ + c_skip = torch.ones_like(t) + c_out = t + c_in = torch.ones_like(t) + c_noise = (0.5 * t).log() + return c_in, c_noise, c_out, c_skip + + +class IDDPMPreconditioner(BaseAffinePreconditioner): + r""" + Improved DDPM (iDDPM) preconditioner. + + Implements the preconditioning scheme from the improved DDPM + formulation. + + The preconditioning coefficients are: + + .. 
math:: + + c_{\text{skip}} &= 1 \\ + c_{\text{out}} &= -\sigma \\ + c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\ + c_{\text{noise}} &= M - 1 - \text{argmin}|\sigma - u_j| + + where :math:`u_j, j = 0, ..., M` are the precomputed noise levels in the + noise schedule. + + Parameters + ---------- + model : physicsnemo.Module + The underlying neural network model to wrap with signature described in + :class:`BaseAffinePreconditioner`. + C_1 : float, optional + Timestep adjustment at low noise levels, by default 0.001. + C_2 : float, optional + Timestep adjustment at high noise levels, by default 0.008. + M : int, optional + Number of discretization steps in the DDPM formulation, + by default 1000. + + Forward + ------- + x : torch.Tensor + Noisy latent state of shape :math:`(B, *)` where :math:`B` is the + batch size and :math:`*` denotes any number of additional dimensions. + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)`. + condition : TensorDict + TensorDict containing conditioning tensors with batch size :math:`B` + matching that of ``x``. Passed to the wrapped ``model`` unchanged. + **model_kwargs : Any + Additional keyword arguments passed to the underlying model. + + Outputs + ------- + torch.Tensor + Preconditioned model output with the same shape as the original model + output. + + Note + ---- + Reference: `Improved Denoising Diffusion Probabilistic Models + `_ + + Examples + -------- + >>> import torch + >>> from tensordict import TensorDict + >>> from physicsnemo.core import Module + >>> # Define a simple model satisfying the diffusion model interface + >>> class SimpleModel(Module): + ... def __init__(self, channels: int): + ... super().__init__() + ... self.net = torch.nn.Conv2d(channels, channels, 1) + ... def forward(self, x, t, condition): + ... return self.net(x) + >>> model = SimpleModel(channels=3) + >>> precond = IDDPMPreconditioner(model, C_1=0.001, C_2=0.008, M=1000) + >>> x = torch.randn(2, 3, 16, 16) # batch of 2 images + >>> t = torch.rand(2) # diffusion time for each sample + >>> condition = TensorDict({}, batch_size=[2]) + >>> out = precond(x, t, condition) + >>> out.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + model: Module, + C_1: float = 0.001, + C_2: float = 0.008, + M: int = 1000, + ) -> None: + super().__init__(model) + self.register_buffer("C_1", torch.tensor(C_1)) + self.register_buffer("C_2", torch.tensor(C_2)) + self.register_buffer("M", torch.tensor(M)) + + # Precompute the noise level schedule u_j, j = 0, ..., M + u = torch.zeros(M + 1) + for j in range(M, 0, -1): + angle_j = 0.5 * math.pi * j / M / (C_2 + 1) + angle_jm1 = 0.5 * math.pi * (j - 1) / M / (C_2 + 1) + alpha_bar_j = math.sin(angle_j) ** 2 + alpha_bar_jm1 = math.sin(angle_jm1) ** 2 + alpha_ratio = alpha_bar_jm1 / alpha_bar_j + u[j - 1] = ((u[j] ** 2 + 1) / max(alpha_ratio, C_1) - 1).sqrt() + self.register_buffer("u", u) + + def compute_coefficients( + self, t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute iDDPM preconditioning coefficients. + + Parameters + ---------- + t : torch.Tensor + Diffusion time tensor of shape :math:`(B, 1, ..., 1)`. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Preconditioning coefficients (:math:`c_{\text{in}}`, + :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`, + :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`. 
+ """ + c_skip = torch.ones_like(t) + c_out = -t + c_in = 1 / (t**2 + 1).sqrt() + + # Round sigma to nearest index in precomputed schedule u + u: torch.Tensor = self.u # type: ignore[assignment] + t_flat = t.reshape(1, -1, 1) + u_reshaped = u.reshape(1, -1, 1) + idx = torch.cdist(t_flat, u_reshaped).argmin(2).reshape(t.shape) + c_noise = self.M - 1 - idx + + return c_in, c_noise, c_out, c_skip + + +class EDMPreconditioner(BaseAffinePreconditioner): + r""" + EDM preconditioner. + + Implements the improved preconditioning scheme proposed in the EDM + paper. + + For EDM, the noise schedule is identity: :math:`\sigma(t) = t`. + + The preconditioning coefficients are: + + .. math:: + + c_{\text{skip}} &= \frac{\sigma_{\text{data}}^2} + {\sigma^2 + \sigma_{\text{data}}^2} \\ + c_{\text{out}} &= \frac{\sigma \cdot \sigma_{\text{data}}} + {\sqrt{\sigma^2 + \sigma_{\text{data}}^2}} \\ + c_{\text{in}} &= \frac{1} + {\sqrt{\sigma_{\text{data}}^2 + \sigma^2}} \\ + c_{\text{noise}} &= \frac{\log(\sigma)}{4} + + Parameters + ---------- + model : physicsnemo.Module + The underlying neural network model to wrap with signature described in + :class:`BaseAffinePreconditioner`. + sigma_data : float, optional + Expected standard deviation of the training data, by default 0.5. + + Forward + ------- + x : torch.Tensor + Noisy latent state of shape :math:`(B, *)` where :math:`B` is the + batch size and :math:`*` denotes any number of additional dimensions. + t : torch.Tensor + Diffusion time tensor of shape :math:`(B,)`. + condition : TensorDict + TensorDict containing conditioning tensors with batch size :math:`B` + matching that of ``x``. Passed to the wrapped ``model`` unchanged. + **model_kwargs : Any + Additional keyword arguments passed to the underlying model. + + Outputs + ------- + torch.Tensor + Preconditioned model output with the same shape as the original model + output. + + Note + ---- + Reference: `Elucidating the Design Space of Diffusion-Based + Generative Models `_ + + Examples + -------- + >>> import torch + >>> from tensordict import TensorDict + >>> from physicsnemo.core import Module + >>> # Define a simple model satisfying the diffusion model interface + >>> class SimpleModel(Module): + ... def __init__(self, channels: int): + ... super().__init__() + ... self.net = torch.nn.Conv2d(channels, channels, 1) + ... def forward(self, x, t, condition): + ... return self.net(x) + >>> model = SimpleModel(channels=3) + >>> precond = EDMPreconditioner(model, sigma_data=0.5) + >>> x = torch.randn(2, 3, 16, 16) # batch of 2 images + >>> t = torch.rand(2) # diffusion time for each sample + >>> condition = TensorDict({}, batch_size=[2]) + >>> out = precond(x, t, condition) + >>> out.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + model: Module, + sigma_data: float = 0.5, + ) -> None: + super().__init__(model) + self.register_buffer("sigma_data", torch.tensor(sigma_data)) + + def compute_coefficients( + self, t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Compute EDM preconditioning coefficients. + + Parameters + ---------- + t : torch.Tensor + Diffusion time (or noise level, since they are identical for EDM) + of shape :math:`(B, 1, ..., 1)`. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Preconditioning coefficients (:math:`c_{\text{in}}`, + :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`, + :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`. 
+ """ + sd = self.sigma_data + c_skip = sd**2 / (t**2 + sd**2) + c_out = t * sd / (t**2 + sd**2).sqrt() + c_in = 1 / (sd**2 + t**2).sqrt() + c_noise = t.log() / 4 + return c_in, c_noise, c_out, c_skip diff --git a/test/diffusion/__init__.py b/test/diffusion/__init__.py new file mode 100644 index 0000000000..b2340c62ce --- /dev/null +++ b/test/diffusion/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/diffusion/data/edm_precond_conv.mdlus b/test/diffusion/data/edm_precond_conv.mdlus new file mode 100644 index 0000000000..65d3559494 Binary files /dev/null and b/test/diffusion/data/edm_precond_conv.mdlus differ diff --git a/test/diffusion/data/edm_precond_conv_coefficients.pth b/test/diffusion/data/edm_precond_conv_coefficients.pth new file mode 100644 index 0000000000..502a7d2cb3 Binary files /dev/null and b/test/diffusion/data/edm_precond_conv_coefficients.pth differ diff --git a/test/diffusion/data/edm_precond_conv_forward.pth b/test/diffusion/data/edm_precond_conv_forward.pth new file mode 100644 index 0000000000..a6421d994b Binary files /dev/null and b/test/diffusion/data/edm_precond_conv_forward.pth differ diff --git a/test/diffusion/data/edm_precond_conv_sigma.pth b/test/diffusion/data/edm_precond_conv_sigma.pth new file mode 100644 index 0000000000..990ffc4ec5 Binary files /dev/null and b/test/diffusion/data/edm_precond_conv_sigma.pth differ diff --git a/test/diffusion/data/edm_precond_linear.mdlus b/test/diffusion/data/edm_precond_linear.mdlus new file mode 100644 index 0000000000..53f07ec6f7 Binary files /dev/null and b/test/diffusion/data/edm_precond_linear.mdlus differ diff --git a/test/diffusion/data/edm_precond_linear_coefficients.pth b/test/diffusion/data/edm_precond_linear_coefficients.pth new file mode 100644 index 0000000000..ea618d80c8 Binary files /dev/null and b/test/diffusion/data/edm_precond_linear_coefficients.pth differ diff --git a/test/diffusion/data/edm_precond_linear_forward.pth b/test/diffusion/data/edm_precond_linear_forward.pth new file mode 100644 index 0000000000..779078c0e2 Binary files /dev/null and b/test/diffusion/data/edm_precond_linear_forward.pth differ diff --git a/test/diffusion/data/edm_precond_linear_sigma.pth b/test/diffusion/data/edm_precond_linear_sigma.pth new file mode 100644 index 0000000000..cfa447e914 Binary files /dev/null and b/test/diffusion/data/edm_precond_linear_sigma.pth differ diff --git a/test/diffusion/data/iddpm_precond_conv.mdlus b/test/diffusion/data/iddpm_precond_conv.mdlus new file mode 100644 index 0000000000..8c76231f13 Binary files /dev/null and b/test/diffusion/data/iddpm_precond_conv.mdlus differ diff --git a/test/diffusion/data/iddpm_precond_conv_coefficients.pth b/test/diffusion/data/iddpm_precond_conv_coefficients.pth new file mode 100644 index 0000000000..0e5c6e03bc Binary files /dev/null and 
b/test/diffusion/data/iddpm_precond_conv_coefficients.pth differ diff --git a/test/diffusion/data/iddpm_precond_conv_forward.pth b/test/diffusion/data/iddpm_precond_conv_forward.pth new file mode 100644 index 0000000000..7edcf85793 Binary files /dev/null and b/test/diffusion/data/iddpm_precond_conv_forward.pth differ diff --git a/test/diffusion/data/iddpm_precond_conv_sigma.pth b/test/diffusion/data/iddpm_precond_conv_sigma.pth new file mode 100644 index 0000000000..ef7ba483d8 Binary files /dev/null and b/test/diffusion/data/iddpm_precond_conv_sigma.pth differ diff --git a/test/diffusion/data/iddpm_precond_linear.mdlus b/test/diffusion/data/iddpm_precond_linear.mdlus new file mode 100644 index 0000000000..618a77662c Binary files /dev/null and b/test/diffusion/data/iddpm_precond_linear.mdlus differ diff --git a/test/diffusion/data/iddpm_precond_linear_coefficients.pth b/test/diffusion/data/iddpm_precond_linear_coefficients.pth new file mode 100644 index 0000000000..f1e5f473f6 Binary files /dev/null and b/test/diffusion/data/iddpm_precond_linear_coefficients.pth differ diff --git a/test/diffusion/data/iddpm_precond_linear_forward.pth b/test/diffusion/data/iddpm_precond_linear_forward.pth new file mode 100644 index 0000000000..0e11d42fc5 Binary files /dev/null and b/test/diffusion/data/iddpm_precond_linear_forward.pth differ diff --git a/test/diffusion/data/iddpm_precond_linear_sigma.pth b/test/diffusion/data/iddpm_precond_linear_sigma.pth new file mode 100644 index 0000000000..2ddca8da83 Binary files /dev/null and b/test/diffusion/data/iddpm_precond_linear_sigma.pth differ diff --git a/test/diffusion/data/ve_precond_conv.mdlus b/test/diffusion/data/ve_precond_conv.mdlus new file mode 100644 index 0000000000..598f4bab8e Binary files /dev/null and b/test/diffusion/data/ve_precond_conv.mdlus differ diff --git a/test/diffusion/data/ve_precond_conv_coefficients.pth b/test/diffusion/data/ve_precond_conv_coefficients.pth new file mode 100644 index 0000000000..f18464dbb6 Binary files /dev/null and b/test/diffusion/data/ve_precond_conv_coefficients.pth differ diff --git a/test/diffusion/data/ve_precond_conv_forward.pth b/test/diffusion/data/ve_precond_conv_forward.pth new file mode 100644 index 0000000000..1940ffda06 Binary files /dev/null and b/test/diffusion/data/ve_precond_conv_forward.pth differ diff --git a/test/diffusion/data/ve_precond_conv_sigma.pth b/test/diffusion/data/ve_precond_conv_sigma.pth new file mode 100644 index 0000000000..0d16665a4c Binary files /dev/null and b/test/diffusion/data/ve_precond_conv_sigma.pth differ diff --git a/test/diffusion/data/ve_precond_linear.mdlus b/test/diffusion/data/ve_precond_linear.mdlus new file mode 100644 index 0000000000..876927ce7c Binary files /dev/null and b/test/diffusion/data/ve_precond_linear.mdlus differ diff --git a/test/diffusion/data/ve_precond_linear_coefficients.pth b/test/diffusion/data/ve_precond_linear_coefficients.pth new file mode 100644 index 0000000000..dcd3babfb9 Binary files /dev/null and b/test/diffusion/data/ve_precond_linear_coefficients.pth differ diff --git a/test/diffusion/data/ve_precond_linear_forward.pth b/test/diffusion/data/ve_precond_linear_forward.pth new file mode 100644 index 0000000000..7a311f29c2 Binary files /dev/null and b/test/diffusion/data/ve_precond_linear_forward.pth differ diff --git a/test/diffusion/data/ve_precond_linear_sigma.pth b/test/diffusion/data/ve_precond_linear_sigma.pth new file mode 100644 index 0000000000..f8098e6e2f Binary files /dev/null and 
b/test/diffusion/data/ve_precond_linear_sigma.pth differ diff --git a/test/diffusion/data/vp_precond_conv.mdlus b/test/diffusion/data/vp_precond_conv.mdlus new file mode 100644 index 0000000000..04caac10d6 Binary files /dev/null and b/test/diffusion/data/vp_precond_conv.mdlus differ diff --git a/test/diffusion/data/vp_precond_conv_coefficients.pth b/test/diffusion/data/vp_precond_conv_coefficients.pth new file mode 100644 index 0000000000..0319b45caa Binary files /dev/null and b/test/diffusion/data/vp_precond_conv_coefficients.pth differ diff --git a/test/diffusion/data/vp_precond_conv_forward.pth b/test/diffusion/data/vp_precond_conv_forward.pth new file mode 100644 index 0000000000..0b9a7696c7 Binary files /dev/null and b/test/diffusion/data/vp_precond_conv_forward.pth differ diff --git a/test/diffusion/data/vp_precond_conv_sigma.pth b/test/diffusion/data/vp_precond_conv_sigma.pth new file mode 100644 index 0000000000..bdc2280664 Binary files /dev/null and b/test/diffusion/data/vp_precond_conv_sigma.pth differ diff --git a/test/diffusion/data/vp_precond_linear.mdlus b/test/diffusion/data/vp_precond_linear.mdlus new file mode 100644 index 0000000000..b167dfba6a Binary files /dev/null and b/test/diffusion/data/vp_precond_linear.mdlus differ diff --git a/test/diffusion/data/vp_precond_linear_coefficients.pth b/test/diffusion/data/vp_precond_linear_coefficients.pth new file mode 100644 index 0000000000..c26a48fe10 Binary files /dev/null and b/test/diffusion/data/vp_precond_linear_coefficients.pth differ diff --git a/test/diffusion/data/vp_precond_linear_forward.pth b/test/diffusion/data/vp_precond_linear_forward.pth new file mode 100644 index 0000000000..b61cdb96b8 Binary files /dev/null and b/test/diffusion/data/vp_precond_linear_forward.pth differ diff --git a/test/diffusion/data/vp_precond_linear_sigma.pth b/test/diffusion/data/vp_precond_linear_sigma.pth new file mode 100644 index 0000000000..e4cf2d823b Binary files /dev/null and b/test/diffusion/data/vp_precond_linear_sigma.pth differ diff --git a/test/diffusion/helpers.py b/test/diffusion/helpers.py new file mode 100644 index 0000000000..7333c3ca9c --- /dev/null +++ b/test/diffusion/helpers.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for diffusion preconditioner tests.""" + +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +from tensordict import TensorDict + +import physicsnemo.core + +# Directory for test reference data +DATA_DIR = Path(__file__).parent / "data" + + +def instantiate_model_deterministic( + cls, + seed: int = 0, + **kwargs: Any, +) -> physicsnemo.core.Module: + """ + Instantiate a model with deterministic random parameters. 
+ """ + model = cls(**kwargs) + gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + with torch.no_grad(): + for param in model.parameters(): + param.copy_( + torch.randn( + param.shape, + generator=gen, + dtype=param.dtype, + ) + ) + return model + + +def generate_batch_data( + shape: Tuple[int, ...] = (4, 3, 16, 16), + seed: int = 42, + device: str = "cpu", + use_condition: bool = False, +) -> Dict[str, torch.Tensor | TensorDict]: + """ + Generate deterministic batch data for testing. + + Parameters + ---------- + shape : Tuple[int, ...] + Shape of the input tensor x. + seed : int + Random seed for deterministic generation. + device : str + Device to place tensors on. + use_condition : bool + If True, generates condition["y"] with the same shape as x. + + Returns + ------- + Dict containing: + - "x": Input tensor of given shape + - "t": Time tensor of shape (batch_size,) + - "condition": TensorDict with batch_size matching x + """ + gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + + batch_size = shape[0] + x = torch.randn(*shape, generator=gen) + # Use positive t values away from 0 to avoid log(0) issues + t = torch.rand(batch_size, generator=gen) * 0.5 + 0.4 + + # Generate condition as TensorDict with batch_size + if use_condition: + condition = TensorDict( + {"y": torch.randn(*shape, generator=gen).to(device)}, + batch_size=[batch_size], + ) + else: + condition = TensorDict({}, batch_size=[batch_size]) + + return { + "x": x.to(device), + "t": t.to(device), + "condition": condition.to(device), + } + + +def load_or_create_reference( + file_name: str, + compute_fn: Optional[Callable[[], Dict[str, torch.Tensor]]], + *, + force_recreate: bool = False, +) -> Dict[str, torch.Tensor]: + """ + Load reference data from file, or create it if it doesn't exist. + + Parameters + ---------- + file_name : str + Name of the reference data file (relative to DATA_DIR). + compute_fn : Callable[[], Dict[str, torch.Tensor]] + Function that computes and returns the reference data dictionary. + Called only when reference data needs to be created. + force_recreate : bool, optional + If True, recreate the reference data even if it exists, + by default False. + + Returns + ------- + Dict[str, torch.Tensor] + The reference data dictionary. + """ + file_path = DATA_DIR / file_name + + if file_path.exists() and not force_recreate: + return torch.load(file_path, weights_only=True) + + # Create data directory if it doesn't exist + DATA_DIR.mkdir(parents=True, exist_ok=True) + + # Compute reference data + if compute_fn is None: + raise FileNotFoundError( + f"Reference data not found: {file_path}. " + f"Run test with compute_fn to create it first." + ) + data = compute_fn() + + # Move all tensors to CPU before saving + data_cpu = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data_cpu[k] = v.cpu() + else: + data_cpu[k] = v + + # Save reference data + torch.save(data_cpu, file_path) + + return data + + +def load_or_create_checkpoint( + checkpoint_name: str, + create_fn: Optional[Callable[[], physicsnemo.core.Module]], + force_recreate: bool = False, +) -> physicsnemo.core.Module: + """ + Load checkpoint from file, or create it if it doesn't exist. + """ + checkpoint_path = DATA_DIR / checkpoint_name + + if not checkpoint_path.exists() or force_recreate: + DATA_DIR.mkdir(parents=True, exist_ok=True) + if create_fn is None: + raise FileNotFoundError( + f"Checkpoint not found: {checkpoint_path}. " + f"Run test with create_fn to create it first." 
+ ) + model = create_fn() + model.save(str(checkpoint_path)) + return model + else: + return physicsnemo.core.Module.from_checkpoint(str(checkpoint_path)) + + +def compare_outputs( + actual: torch.Tensor, + expected: torch.Tensor, + atol: float = 1e-5, + rtol: float = 1e-5, +) -> None: + """ + Compare actual and expected tensors with detailed error reporting. + + Parameters + ---------- + actual : torch.Tensor + The computed tensor. + expected : torch.Tensor + The expected reference tensor. + atol : float, optional + Absolute tolerance, by default 1e-5. + rtol : float, optional + Relative tolerance, by default 1e-5. + + Raises + ------ + AssertionError + If tensors don't match within tolerance, with detailed error info. + """ + if actual.shape != expected.shape: + raise AssertionError( + f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}" + ) + + # Move to same device and convert to float64 for comparison + actual_f64 = actual.to(torch.float64) + expected_f64 = expected.to(device=actual.device, dtype=torch.float64) + + torch.testing.assert_close(actual_f64, expected_f64, atol=atol, rtol=rtol) diff --git a/test/diffusion/test_preconditioners.py b/test/diffusion/test_preconditioners.py new file mode 100644 index 0000000000..1cca2bb8b3 --- /dev/null +++ b/test/diffusion/test_preconditioners.py @@ -0,0 +1,658 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for diffusion preconditioners.""" + +from typing import Any, Tuple + +import pytest +import torch +from tensordict import TensorDict + +from physicsnemo.core import Module +from physicsnemo.diffusion.preconditioners import ( + BaseAffinePreconditioner, + EDMPreconditioner, + IDDPMPreconditioner, + VEPreconditioner, + VPPreconditioner, +) + +from .helpers import ( + compare_outputs, + generate_batch_data, + instantiate_model_deterministic, + load_or_create_checkpoint, + load_or_create_reference, +) + +# ============================================================================= +# Test Model Definitions +# ============================================================================= + + +class ConvModel(Module): + """Convolutional model for testing preconditioners with 4D input.""" + + def __init__(self, channels: int = 3): + super().__init__() + self.channels = channels + # Conv2d takes x concatenated with condition["y"] (same shape as x) + in_channels = channels * 2 + self.net = torch.nn.Conv2d(in_channels, channels, kernel_size=1) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + condition: TensorDict, + **kwargs: Any, + ) -> torch.Tensor: + y = condition["y"] + x_cond = torch.cat([x, y], dim=1) + out = self.net(x_cond) + t_scale = t.view(-1, 1, 1, 1) + return out + t_scale + + +class LinearModel(Module): + """Linear model for testing preconditioners with 2D input.""" + + def __init__(self, in_features: int = 64): + super().__init__() + self.in_features = in_features + # Simple linear layer that preserves dimension + self.net = torch.nn.Linear(in_features, in_features) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + condition: TensorDict, + **kwargs: Any, + ) -> torch.Tensor: + out = self.net(x) + t_scale = t.view(-1, 1) + return out + t_scale + + +# ============================================================================= +# Constants and Preconditioner Configurations +# ============================================================================= + + +# Test shapes for different model types +# 4D shape for ConvModel: (batch_size, channels, height, width) +CONV_SHAPE: Tuple[int, ...] = (4, 3, 8, 6) +# 2D shape for LinearModel: (batch_size, features) +LINEAR_SHAPE: Tuple[int, ...] 
= (4, 16) + +# Model configurations for parameterized tests: (model_class, shape, arch_name) +MODEL_CONFIGS = [ + (ConvModel, CONV_SHAPE, "conv"), + (LinearModel, LINEAR_SHAPE, "linear"), +] + +# Preconditioner configurations for parameterized tests +PRECOND_CONFIGS = [ + ( + VPPreconditioner, + {"beta_d": 19.9, "beta_min": 0.1, "M": 2000}, + "vp_precond", + ), + ( + VEPreconditioner, + {}, + "ve_precond", + ), + ( + IDDPMPreconditioner, + {"C_1": 0.001, "C_2": 0.008, "M": 2000}, + "iddpm_precond", + ), + ( + EDMPreconditioner, + {"sigma_data": 1.0}, + "edm_precond", + ), +] + + +# Tolerances for non-regression tests (device-dependent) +# CPU tests use tighter tolerances, GPU tests need more relaxed tolerances +CPU_TOLERANCES = {"atol": 1e-5, "rtol": 1e-5} +GPU_TOLERANCES = {"atol": 1e-2, "rtol": 5e-2} + +# Global random seed for reproducibility +GLOBAL_SEED = 42 + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def deterministic_settings(): + """Set deterministic settings for reproducibility, then restore old state.""" + # Save old state + old_cudnn_deterministic = torch.backends.cudnn.deterministic + old_cudnn_benchmark = torch.backends.cudnn.benchmark + old_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32 + old_cudnn_tf32 = torch.backends.cudnn.allow_tf32 + + try: + # Set deterministic settings + torch.manual_seed(GLOBAL_SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(GLOBAL_SEED) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + yield + finally: + # Restore old state + torch.backends.cudnn.deterministic = old_cudnn_deterministic + torch.backends.cudnn.benchmark = old_cudnn_benchmark + torch.backends.cuda.matmul.allow_tf32 = old_matmul_tf32 + torch.backends.cudnn.allow_tf32 = old_cudnn_tf32 + + +@pytest.fixture +def tolerances(device): + """Return tolerances based on the device (CPU vs GPU).""" + if device == "cpu": + return CPU_TOLERANCES + return GPU_TOLERANCES + + +@pytest.fixture(params=MODEL_CONFIGS, ids=["ConvModel", "LinearModel"]) +def model_config(request): + """Parameterized fixture returning (model_class, shape, arch_name).""" + return request.param + + +@pytest.fixture +def test_shape(model_config): + """Return the test shape for the current model class.""" + _, shape, _ = model_config + return shape + + +@pytest.fixture +def arch_name(model_config): + """Return the architecture name for reference file naming.""" + _, _, name = model_config + return name + + +@pytest.fixture +def simple_model(model_config): + """Create a model with deterministic parameters.""" + cls, shape, _ = model_config + if cls == LinearModel: + return instantiate_model_deterministic(cls, seed=0, in_features=shape[1]) + return instantiate_model_deterministic(cls, seed=0, channels=shape[1]) + + +@pytest.fixture +def batch_data(model_config, device): + """Create deterministic batch data matching the model's expected shape.""" + model_cls, shape, _ = model_config + # ConvModel uses condition, LinearModel does not + use_condition = model_cls == ConvModel + return generate_batch_data( + shape=shape, seed=42, device=device, use_condition=use_condition + ) + + +def create_model_deterministic(model_cls, shape): + """Create a model with deterministic parameters for the given shape.""" + if model_cls == LinearModel: + 
return instantiate_model_deterministic(model_cls, seed=0, in_features=shape[1]) + return instantiate_model_deterministic(model_cls, seed=0, channels=shape[1]) + + +def create_preconditioner(precond_cls, precond_kwargs, model_cls, shape): + """Create a preconditioner with deterministic model.""" + model = create_model_deterministic(model_cls, shape) + return precond_cls(model, **precond_kwargs) + + +# ============================================================================= +# VPPreconditioner Tests +# ============================================================================= + + +class TestVPPreconditioner: + """Tests for VPPreconditioner.""" + + @pytest.mark.parametrize( + "config,beta_d,beta_min,M", + [ + ("default", 19.9, 0.1, 1000), + ("custom", 10.0, 0.05, 500), + ], + ids=["default", "custom"], + ) + def test_constructor_attributes(self, simple_model, config, beta_d, beta_min, M): + """Test VPPreconditioner constructor and attributes.""" + if config == "default": + # Test with default values - verify against known defaults + precond = VPPreconditioner(simple_model) + assert precond.beta_d.item() == pytest.approx(19.9) + assert precond.beta_min.item() == pytest.approx(0.1) + assert precond.M.item() == 1000 + else: + # Test with custom values - verify against passed arguments + precond = VPPreconditioner( + simple_model, beta_d=beta_d, beta_min=beta_min, M=M + ) + assert precond.beta_d.item() == pytest.approx(beta_d) + assert precond.beta_min.item() == pytest.approx(beta_min) + assert precond.M.item() == M + + assert precond.model is simple_model + assert isinstance(precond, BaseAffinePreconditioner) + + +# ============================================================================= +# VEPreconditioner Tests +# ============================================================================= + + +class TestVEPreconditioner: + """Tests for VEPreconditioner.""" + + def test_constructor_attributes(self, simple_model): + """Test VEPreconditioner constructor and attributes.""" + precond = VEPreconditioner(simple_model) + + assert precond.model is simple_model + assert isinstance(precond, BaseAffinePreconditioner) + + +# ============================================================================= +# IDDPMPreconditioner Tests +# ============================================================================= + + +class TestIDDPMPreconditioner: + """Tests for IDDPMPreconditioner.""" + + @pytest.mark.parametrize( + "config,C_1,C_2,M", + [ + ("default", 0.001, 0.008, 1000), + ("custom", 0.002, 0.01, 500), + ], + ids=["default", "custom"], + ) + def test_constructor_attributes(self, simple_model, config, C_1, C_2, M): + """Test IDDPMPreconditioner constructor and attributes.""" + if config == "default": + # Test with default values - verify against known defaults + precond = IDDPMPreconditioner(simple_model) + assert precond.C_1.item() == pytest.approx(0.001) + assert precond.C_2.item() == pytest.approx(0.008) + assert precond.M.item() == 1000 + expected_M = 1000 + else: + # Test with custom values - verify against passed arguments + precond = IDDPMPreconditioner(simple_model, C_1=C_1, C_2=C_2, M=M) + assert precond.C_1.item() == pytest.approx(C_1) + assert precond.C_2.item() == pytest.approx(C_2) + assert precond.M.item() == M + expected_M = M + + assert hasattr(precond, "u") + assert precond.u.shape == (expected_M + 1,) + assert isinstance(precond, BaseAffinePreconditioner) + + +# ============================================================================= +# EDMPreconditioner Tests +# 
============================================================================= + + +class TestEDMPreconditioner: + """Tests for EDMPreconditioner.""" + + @pytest.mark.parametrize( + "config,sigma_data", + [ + ("default", 0.5), + ("custom", 1.0), + ], + ids=["default", "custom"], + ) + def test_constructor_attributes(self, simple_model, config, sigma_data): + """Test EDMPreconditioner constructor and attributes.""" + if config == "default": + # Test with default values - verify against known defaults + precond = EDMPreconditioner(simple_model) + assert precond.sigma_data.item() == pytest.approx(0.5) + else: + # Test with custom values - verify against passed arguments + precond = EDMPreconditioner(simple_model, sigma_data=sigma_data) + assert precond.sigma_data.item() == pytest.approx(sigma_data) + + assert precond.model is simple_model + assert isinstance(precond, BaseAffinePreconditioner) + + +# ============================================================================= +# Non-Regression Tests (Parameterized Across All Preconditioners and Models) +# ============================================================================= + + +@pytest.mark.parametrize( + "precond_cls,precond_kwargs,precond_name", + PRECOND_CONFIGS, + ids=["VP", "VE", "iDDPM", "EDM"], +) +class TestNonRegression: + """Non-regression tests parameterized across all preconditioner types.""" + + def test_sigma_non_regression( + self, + deterministic_settings, + model_config, + batch_data, + device, + tolerances, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test sigma(t) against reference data.""" + model_cls, shape, arch_name = model_config + precond = create_preconditioner( + precond_cls, precond_kwargs, model_cls, shape + ).to(device) + + t = batch_data["t"] + sigma = precond.sigma(t) + + ref_file = f"{precond_name}_{arch_name}_sigma.pth" + ref_data = load_or_create_reference(ref_file, lambda: {"sigma": sigma.cpu()}) + + compare_outputs(sigma, ref_data["sigma"], **tolerances) + + def test_sigma_from_checkpoint( + self, + deterministic_settings, + model_config, + batch_data, + device, + tolerances, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test sigma(t) from loaded checkpoint matches reference.""" + model_cls, shape, arch_name = model_config + + def create_fn(): + return create_preconditioner(precond_cls, precond_kwargs, model_cls, shape) + + ckpt_file = f"{precond_name}_{arch_name}.mdlus" + precond = load_or_create_checkpoint(ckpt_file, create_fn).to(device) + + t = batch_data["t"] + sigma = precond.sigma(t) + + ref_file = f"{precond_name}_{arch_name}_sigma.pth" + ref_data = load_or_create_reference(ref_file, lambda: {"sigma": sigma.cpu()}) + + compare_outputs(sigma, ref_data["sigma"], **tolerances) + + def test_coefficients_non_regression( + self, + deterministic_settings, + model_config, + batch_data, + device, + tolerances, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test compute_coefficients against reference data.""" + model_cls, shape, arch_name = model_config + precond = create_preconditioner( + precond_cls, precond_kwargs, model_cls, shape + ).to(device) + + # Reshape t to sigma shape: (B, 1, ..., 1) + batch_size = shape[0] + sigma_shape = (batch_size,) + (1,) * (len(shape) - 1) + sigma = batch_data["t"].view(sigma_shape) + + c_in, c_noise, c_out, c_skip = precond.compute_coefficients(sigma) + + # Load existing reference or save current output as reference + ref_file = f"{precond_name}_{arch_name}_coefficients.pth" + ref_data = load_or_create_reference( + ref_file, + 
lambda: { + "c_in": c_in.cpu(), + "c_noise": c_noise.cpu(), + "c_out": c_out.cpu(), + "c_skip": c_skip.cpu(), + }, + ) + + compare_outputs(c_in, ref_data["c_in"], **tolerances) + compare_outputs(c_noise, ref_data["c_noise"], **tolerances) + compare_outputs(c_out, ref_data["c_out"], **tolerances) + compare_outputs(c_skip, ref_data["c_skip"], **tolerances) + + def test_coefficients_from_checkpoint( + self, + deterministic_settings, + model_config, + batch_data, + device, + tolerances, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test compute_coefficients from checkpoint matches reference.""" + model_cls, shape, arch_name = model_config + + def create_fn(): + return create_preconditioner(precond_cls, precond_kwargs, model_cls, shape) + + ckpt_file = f"{precond_name}_{arch_name}.mdlus" + precond = load_or_create_checkpoint(ckpt_file, create_fn).to(device) + + # Reshape t to sigma shape: (B, 1, ..., 1) + batch_size = shape[0] + sigma_shape = (batch_size,) + (1,) * (len(shape) - 1) + sigma = batch_data["t"].view(sigma_shape) + + c_in, c_noise, c_out, c_skip = precond.compute_coefficients(sigma) + + ref_file = f"{precond_name}_{arch_name}_coefficients.pth" + ref_data = load_or_create_reference( + ref_file, + lambda: { + "c_in": c_in.cpu(), + "c_noise": c_noise.cpu(), + "c_out": c_out.cpu(), + "c_skip": c_skip.cpu(), + }, + ) + + compare_outputs(c_in, ref_data["c_in"], **tolerances) + compare_outputs(c_noise, ref_data["c_noise"], **tolerances) + compare_outputs(c_out, ref_data["c_out"], **tolerances) + compare_outputs(c_skip, ref_data["c_skip"], **tolerances) + + def test_forward_non_regression( + self, + deterministic_settings, + model_config, + batch_data, + device, + tolerances, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test forward pass against reference data.""" + model_cls, shape, arch_name = model_config + precond = create_preconditioner( + precond_cls, precond_kwargs, model_cls, shape + ).to(device) + + x = batch_data["x"] + t = batch_data["t"] + condition = batch_data["condition"] + out = precond(x, t, condition) + + ref_file = f"{precond_name}_{arch_name}_forward.pth" + ref_data = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + + compare_outputs(out, ref_data["out"], **tolerances) + + def test_forward_from_checkpoint( + self, + deterministic_settings, + model_config, + batch_data, + device, + tolerances, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test forward pass from loaded checkpoint matches reference.""" + model_cls, shape, arch_name = model_config + + def create_fn(): + return create_preconditioner(precond_cls, precond_kwargs, model_cls, shape) + + ckpt_file = f"{precond_name}_{arch_name}.mdlus" + precond = load_or_create_checkpoint(ckpt_file, create_fn).to(device) + + x = batch_data["x"] + t = batch_data["t"] + condition = batch_data["condition"] + out = precond(x, t, condition) + + ref_file = f"{precond_name}_{arch_name}_forward.pth" + ref_data = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + + compare_outputs(out, ref_data["out"], **tolerances) + + +# ============================================================================= +# Other tests for all preconditioner types +# ============================================================================= + + +@pytest.mark.parametrize( + "precond_cls,precond_kwargs,precond_name", + PRECOND_CONFIGS, + ids=["VP", "VE", "iDDPM", "EDM"], +) +class TestAllPreconditioners: + """Tests that apply to all preconditioner types.""" + + def 
test_forward_input_validation( + self, + simple_model, + batch_data, + device, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test forward validates input shapes.""" + precond = precond_cls(simple_model, **precond_kwargs).to(device) + x = batch_data["x"] + t_wrong = torch.rand(2, device=device) # Wrong batch size + condition = batch_data["condition"] + + with pytest.raises(ValueError, match="Expected t to have shape"): + precond(x, t_wrong, condition) + + def test_forward_dtype_preservation( + self, + simple_model, + batch_data, + device, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test forward preserves input dtype.""" + precond = precond_cls(simple_model, **precond_kwargs).to(device) + x = batch_data["x"] + t = batch_data["t"] + condition = batch_data["condition"] + + output = precond(x, t, condition) + + assert output.dtype == x.dtype + + def test_condition_batch_validation( + self, + simple_model, + test_shape, + device, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test condition batch size validation.""" + precond = precond_cls(simple_model, **precond_kwargs).to(device) + x = torch.randn(*test_shape, device=device) + t = torch.rand(test_shape[0], device=device) + # Wrong batch size in condition TensorDict + condition = TensorDict( + {"cond": torch.randn(2, 10, device=device)}, + batch_size=[2], + ) + + with pytest.raises(ValueError, match="batch size"): + precond(x, t, condition) + + def test_gradient_flow( + self, + simple_model, + batch_data, + device, + precond_cls, + precond_kwargs, + precond_name, + ): + """Test gradients flow through the preconditioner.""" + precond = precond_cls(simple_model, **precond_kwargs).to(device) + x = batch_data["x"].clone().requires_grad_(True) + t = batch_data["t"] + condition = batch_data["condition"] + + output = precond(x, t, condition) + loss = output.sum() + loss.backward() + + assert x.grad is not None + assert not torch.isnan(x.grad).any()
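+
+
+# =============================================================================
+# Illustrative protocol-conformance sketch
+# =============================================================================
+
+
+# NOTE: This final test is an illustrative sketch. It assumes that a
+# preconditioner wrapping a model is itself usable through the
+# ``DiffusionModel`` protocol, i.e. callable as ``precond(x, t, condition)``
+# with an output shaped like ``x``. Keep in mind that ``isinstance`` against
+# the runtime-checkable protocol only verifies that ``__call__`` exists, not
+# that its signature matches.
+@pytest.mark.parametrize(
+    "precond_cls,precond_kwargs,precond_name",
+    PRECOND_CONFIGS,
+    ids=["VP", "VE", "iDDPM", "EDM"],
+)
+def test_preconditioner_acts_as_diffusion_model(
+    precond_cls, precond_kwargs, precond_name
+):
+    """Sketch: a wrapped preconditioner should behave like a DiffusionModel."""
+    # Local import keeps the sketch self-contained without touching the
+    # module-level import block.
+    from physicsnemo.diffusion import DiffusionModel
+
+    model = instantiate_model_deterministic(
+        ConvModel, seed=0, channels=CONV_SHAPE[1]
+    )
+    precond = precond_cls(model, **precond_kwargs)
+
+    batch_size = CONV_SHAPE[0]
+    x = torch.randn(*CONV_SHAPE)
+    # Keep t away from 0 so every noise schedule stays in a benign range.
+    t = torch.linspace(0.1, 1.0, batch_size)
+    # ConvModel reads ``condition["y"]`` with the same shape as ``x``.
+    condition = TensorDict(
+        {"y": torch.randn(*CONV_SHAPE)}, batch_size=[batch_size]
+    )
+
+    # Structural check: the runtime-checkable protocol accepts the wrapper.
+    assert isinstance(precond, DiffusionModel)
+    # Behavioral check: calling it like a DiffusionModel preserves the shape.
+    out = precond(x, t, condition)
+    assert out.shape == x.shape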