
Conversation


@CharlelieLrt CharlelieLrt commented Jan 8, 2026

PhysicsNeMo Pull Request

Overview

This PR refactors the diffusion model preconditioners to introduce a clean, extensible architecture. The goal is to standardize how preconditioners wrap neural network models and apply the preconditioning formula, making it easier to implement new preconditioning schemes and maintain existing ones.

  • New BasePreconditioner abstract base class in physicsnemo/diffusion/preconditioners/preconditioners.py that defines a standardized interface for wrapping diffusion models with preconditioning. The base class handles the common forward pass logic, while subclasses implement compute_coefficients() to define their specific preconditioning scheme. Subclasses can optionally override sigma() to implement custom noise schedules (time-to-noise mappings). A key improvement of this refactor is enabling dependency-injection design patterns by wrapping an arbitrary physicsnemo.Module instance with a preconditioner (a short illustrative sketch follows this list).

  • Four preconditioner reimplementations based on the new BasePreconditioner: VPPreconditioner, VEPreconditioner, IDDPMPreconditioner, and EDMPreconditioner. These are cleaner, standalone versions of the existing legacy preconditioners with comprehensive docstrings including mathematical formulas for the preconditioning coefficients and noise schedules.

  • Migrated legacy preconditioners in legacy.py to inherit from the new base classes. This eliminates code duplication while maintaining full backward compatibility—all existing method signatures, attributes, and behaviors are preserved. Users of the legacy API do not need to change their code.

  • Comprehensive CI tests in test/diffusion/test_preconditioners.py covering:

    • Constructor instantiation and attribute verification
    • sigma() and compute_coefficients() methods
    • Forward pass with non-regression testing against saved reference data
    • Checkpoint save/load roundtrip via physicsnemo.Module.from_checkpoint()
  • Closes 🚀[FEA]: Allow passing models as to EDMPrecond #796 .
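
As a rough, hedged sketch of the dependency-injection pattern described above (the import path, constructor arguments, and the return convention of compute_coefficients() are assumptions here, not necessarily the final API):

```python
import torch
from physicsnemo.diffusion.preconditioners import BasePreconditioner, EDMPreconditioner

# Dependency injection: wrap an existing physicsnemo.Module (`net` is hypothetical)
# instead of baking the network architecture into the preconditioner.
# precond = EDMPreconditioner(net)

# A new scheme only needs to define its coefficients (and, optionally, a custom sigma()):
class MyPreconditioner(BasePreconditioner):
    def compute_coefficients(self, sigma: torch.Tensor):
        # Illustrative EDM-like scalings with sigma_data = 1; not the PR's exact formulas.
        c_skip = 1.0 / (1.0 + sigma**2)
        c_out = sigma / (1.0 + sigma**2).sqrt()
        c_in = 1.0 / (1.0 + sigma**2).sqrt()
        c_noise = sigma.log()
        return c_skip, c_out, c_in, c_noise
```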

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Refactors diffusion preconditioners by introducing a clean BasePreconditioner abstract class that standardizes how preconditioners wrap neural networks and apply preconditioning formulas. Implements four preconditioners (VPPreconditioner, VEPreconditioner, IDDPMPreconditioner, EDMPreconditioner) as standalone classes with comprehensive docstrings including mathematical formulas. Migrates legacy preconditioners to inherit from new base classes, eliminating code duplication while maintaining full backward compatibility with existing APIs.

Important Files Changed

File Analysis

| Filename | Score | Overview |
|----------|-------|----------|
| physicsnemo/diffusion/preconditioners/preconditioners.py | 4/5 | Introduces clean BasePreconditioner architecture and four preconditioner implementations with comprehensive docstrings |
| physicsnemo/diffusion/preconditioners/legacy.py | 4/5 | Migrates legacy preconditioners to inherit from new base classes while maintaining full backward compatibility |
| test/diffusion/test_preconditioners.py | 4/5 | Comprehensive tests covering constructors, sigma/coefficients methods, forward pass, and checkpoint loading with non-regression testing |

@CharlelieLrt CharlelieLrt self-assigned this Jan 8, 2026
@CharlelieLrt CharlelieLrt added the 3 - Ready for Review Ready for review by team label Jan 8, 2026
condition : Dict[str, torch.Tensor]
Dictionary of conditioning tensors. Each tensor must have shape
:math:`(B, *)` where the batch size :math:`B` matches that of ``x``.
These are passed to the wrapped ``model`` without modification.
Collaborator

This should be marked optional, no? What about unconditional diffusion models? Similarly, the underlying models themselves should not have to require a condition argument.

Collaborator Author

Yeah, I was thinking about making condition optional, but if we make it an optional argument, how do we separate it from those in **model_kwargs? There might be conflicts and mix-ups between the two.

Similarly the underlying models themselves should not have to require a condition argument

Right, but then there's a similar problem: how do we differentiate between models that require one and others that don't?
I found making condition required was the cleanest solution to remove all types of confusion and ambiguity, even though this argument is not needed in many cases.

But if you have a better idea, I am open to it.

Collaborator

Ugh, yeah, good points; the model kwargs make this annoying. It seems hard to avoid some sort of awkwardness here, apart from flat-out defining separate preconditioners for the conditional and unconditional cases. To me it just feels wrong to have to wrap an underlying unconditional backbone with something just to pass a dummy condition arg, and similarly for the top-level preconditioner, but I'm having trouble coming up with alternative solutions.

Collaborator

What do you think of a capability flag conditional passed to the preconditioner init, which would specify whether or not the preconditioner (and the underlying model) are expected to use the condition arg? Then we could have

forward(x, condition: TensorDict | None = None, **model_kwargs)

where, within the forward pass, we call things based on the value of self.conditional.
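
A rough, hypothetical sketch of that idea (torch.nn.Module stands in for physicsnemo.Module, and the coefficient computation is omitted so the sketch focuses on how the flag controls the condition argument):

```python
import torch

class FlaggedPreconditioner(torch.nn.Module):
    """Hypothetical sketch of the capability-flag proposal; not the PR's API."""

    def __init__(self, model: torch.nn.Module, conditional: bool = False):
        super().__init__()
        self.model = model
        self.conditional = conditional  # the capability flag proposed above

    def forward(self, x, t, condition=None, **model_kwargs):
        # Preconditioning coefficients would be computed from t here (omitted).
        if self.conditional:
            if condition is None:
                raise ValueError("conditional=True requires a condition argument")
            return self.model(x, t, condition=condition, **model_kwargs)
        if condition is not None:
            raise TypeError("this preconditioner was built with conditional=False")
        return self.model(x, t, **model_kwargs)
```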

Collaborator

An alternate, and possibly spicier, suggestion: we drop the mention of condition entirely from the forward signature. It is absorbed into model_kwargs, may or may not be passed, and it is up to the user to do the input validation (which is the only operation on condition within the forward of the preconditioner). This would also allow for flexible nomenclature of the conditioning in principle, i.e. if some bespoke model wants to call it era5_condition (and optionally add others like time_of_day_condition, etc.), it is welcome to.
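
A minimal, hypothetical sketch of that alternative (the preconditioner never names condition; whatever conditioning kwargs the wrapped model understands simply flow through **model_kwargs):

```python
import torch

class AgnosticPreconditioner(torch.nn.Module):
    """Hypothetical sketch: conditioning is just another model kwarg."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x, t, **model_kwargs):
        # Coefficient computation from t omitted; only the pass-through matters here.
        # `era5_condition`, `time_of_day_condition`, etc. are all welcome, since the
        # wrapped model does its own input validation.
        return self.model(x, t, **model_kwargs)
```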

Collaborator

(That is, unless in the samplers there's some specific operation that needs the conditional fields explicitly.)

Collaborator Author

What do you think of a capability flag conditional passed to the preconditioner init, which would specify whether or not the preconditioner (and the underlying model) are expected to use the condition arg?

@pzharrington if it were only for the preconditioner, I would say okay. But we would need this conditional flag everywhere the model is passed as a callback (loss, samplers, etc.). That would make things very heavy IMO. Also, the loss object is supposed to be purely functional, so the flag would need to be passed to __call__ and not __init__.

I prefer the second alternative that you proposed. Another option I just thought of would be to make the condition argument keyword-only, with a signature model(x, t, *, condition=None, **model_kwargs). That removes some possible ambiguity and mix-ups because, by forcing condition to always be passed by name, we forbid calls such as model(x, t, condition={...}, kwargs1=..., condition={...}, other_kwarg=...), which are invalid anyway.
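
For concreteness, a tiny hypothetical illustration of the keyword-only behavior (plain Python, no PhysicsNeMo API implied):

```python
def precond_forward(x, t, *, condition=None, **model_kwargs):
    # `condition` can only be passed by name, so it can never be swallowed
    # positionally or silently absorbed into **model_kwargs.
    return condition, model_kwargs

precond_forward(1.0, 0.5, condition={"y": 2.0}, other_kwarg=3.0)   # OK
# precond_forward(1.0, 0.5, {"y": 2.0})  # TypeError: too many positional arguments
```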

Collaborator Author

@NickGeneva 's comment above reminded me of another reason why I wanted an explicit argument for condition. In multi-diffusion (whose implementation is still at the stage of philosophical reflection as of now), the model needs to know which argument is the condition, because it needs to apply specific operations to the conditioning tensors (patching, and sometimes interpolation). So, in multi-diffusion, condition cannot be treated the same as the other model_kwargs.

Roughly, it should look like:

multi_diffusion_model = MultiDiffusionModel(model, patching_options)
x0 = multi_diffusion_model(x, t, condition={"y": y, "z": z}, some_kwarg=some_val)

Under the hood, the forward pass of multi_diffusion_model applies patching (and optionally interpolation) and concatenation to the items in condition, but it leaves all other kwargs untouched.
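
A hedged sketch of what that forward pass could look like (MultiDiffusionModel and the patching helper are hypothetical and still at the design stage, as noted above):

```python
import torch

class MultiDiffusionModel(torch.nn.Module):
    """Hypothetical sketch: `condition` gets patched/concatenated, other kwargs don't."""

    def __init__(self, model: torch.nn.Module, patching):
        super().__init__()
        self.model = model
        self.patching = patching  # hypothetical patching (and interpolation) utility

    def forward(self, x, t, condition=None, **model_kwargs):
        x = self.patching.apply(x)
        if condition is not None:
            # Patch (and optionally interpolate) each conditioning tensor, then
            # concatenate with the patched input along the channel dimension.
            cond = [self.patching.apply(c) for c in condition.values()]
            x = torch.cat([x, *cond], dim=1)
        # All other kwargs pass through untouched; un-patching/fusing is omitted.
        return self.model(x, t, **model_kwargs)
```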

Collaborator

@pzharrington pzharrington Jan 10, 2026

In multi-diffusion, the model needs to know which argument is condition

And it would, right? The underlying model would define condition in its forward kwargs and handle it accordingly within its forward pass. Under my second suggestion, the preconditioner wrapping it would simply pass it through as part of the **model_kwargs; it doesn't need to know about or explicitly do anything with the condition (there is no condition-dependent preconditioning, and if there was, that's crazy 😅).

E.g., in your snippet, we'd have

multi_diffusion_model = MultiDiffusionModel(model, patching_options) # Wrap a base backbone with multidiffusion
model_precond = EDMPreconditioner(multi_diffusion_model) # Wrapped by preconditioner, ready for training
x0 = model_precond(x, t, condition={"y": y, "z": z}, some_kwarg=some_val)

The preconditioner doesn't care about the conditioning or the patching/interpolation applied to it; it delegates all of that to the underlying models. This is fine, no? Or am I missing something?

Collaborator Author

Hmm I'm trying to understand that...

In your snippet above, what would be the signatures of model, model_precond, and multi_diffusion_model? (And I mean the actual signatures, not just the way they are called, because model(x, t, condition={}, kwarg=val) may look the same whether condition is a separate keyword argument or part of the kwargs, but the signature signals intent.)

@CharlelieLrt
Collaborator Author

/blossom-ci

@CharlelieLrt
Collaborator Author

/blossom-ci

@CharlelieLrt
Collaborator Author

/blossom-ci

@CharlelieLrt
Collaborator Author

/blossom-ci
