Diffusion preconditioners refactor #1317
Conversation
No tests fixed yet.
physicsnemo.utils, launch.config is just gone. It was empty.
Greptile Overview
Greptile Summary
Refactors diffusion preconditioners by introducing a clean BasePreconditioner abstract class that standardizes how preconditioners wrap neural networks and apply preconditioning formulas. Implements four preconditioners (VPPreconditioner, VEPreconditioner, IDDPMPreconditioner, EDMPreconditioner) as standalone classes with comprehensive docstrings including mathematical formulas. Migrates legacy preconditioners to inherit from new base classes, eliminating code duplication while maintaining full backward compatibility with existing APIs.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| physicsnemo/diffusion/preconditioners/preconditioners.py | 4/5 | Introduces clean BasePreconditioner architecture and four preconditioner implementations with comprehensive docstrings |
| physicsnemo/diffusion/preconditioners/legacy.py | 4/5 | Migrates legacy preconditioners to inherit from new base classes while maintaining full backward compatibility |
| test/diffusion/test_preconditioners.py | 4/5 | Comprehensive tests covering constructors, sigma/coefficients methods, forward pass, and checkpoint loading with non-regression testing |
condition : Dict[str, torch.Tensor]
    Dictionary of conditioning tensors. Each tensor must have shape
    :math:`(B, *)` where the batch size :math:`B` matches that of ``x``.
    These are passed to the wrapped ``model`` without modification.
This should be marked optional, no? What about unconditional diffusion models? Similarly, the underlying models themselves should not have to require a `condition` argument.
Yeah, I was thinking about making `condition` optional, but if we make it an optional argument, how do we separate it from those in `**model_kwargs`? There might be conflicts and mix-ups between the two.
Similarly the underlying models themselves should not have to require a condition argument
Right, but then there is a similar problem: how do we differentiate between models that require one and others that don't?
I found that making `condition` required was the cleanest solution to remove all types of confusion and ambiguity, even though this argument is not needed in many cases.
But if you have a better idea I am open to it.
Ugh, yeah, good points; the model kwargs make this annoying. It seems hard to avoid some sort of awkwardness here, apart from flat-out defining separate preconditioners for the conditional and unconditional cases. To me it just feels wrong to have to wrap underlying unconditional backbones with something to pass a dummy condition arg, and similarly for the top-level preconditioner, but I'm having trouble coming up with alternate solutions.
What do you think of a capability flag `conditional` passed to the preconditioner init, which would specify whether or not the preconditioner (and underlying model) are expected to use the `condition` arg? Then we could have
forward(x, condition: TensorDict | None = None, **model_kwargs)
where within the forward pass we call things based on the value of `self.conditional`.
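For concreteness, here is a rough sketch of what that flag could look like. The class name and internals below are hypothetical, not part of this PR, and the actual preconditioning math is omitted; only the branching on `self.conditional` is the point.

```python
from typing import Dict, Optional

import torch


class FlaggedPreconditioner(torch.nn.Module):
    """Hypothetical illustration of the `conditional` capability-flag idea."""

    def __init__(self, model: torch.nn.Module, conditional: bool = False):
        super().__init__()
        self.model = model
        self.conditional = conditional  # declared once, at init time

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        condition: Optional[Dict[str, torch.Tensor]] = None,
        **model_kwargs,
    ) -> torch.Tensor:
        # Preconditioning coefficients omitted; this only shows the branching.
        if self.conditional:
            if condition is None:
                raise ValueError("conditional=True but no `condition` was passed")
            return self.model(x, t, condition=condition, **model_kwargs)
        # Unconditional path: the wrapped model never receives a `condition` kwarg.
        return self.model(x, t, **model_kwargs)
```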
An alternate, and possibly spicier, suggestion: we drop the mention of `condition` entirely from the forward signature. It is absorbed into `model_kwargs`, may or may not be passed, and it is up to the user to do the input validation (which is the only operation on `condition` within the forward of the preconditioner). This would also allow for flexible nomenclature of the conditioning in principle, i.e. if some bespoke model wants to call it `era5_condition` (and optionally add others like `time_of_day_condition`, etc.), then it is welcome to.
(That is, unless in the samplers there's some specific operation that needs the conditional fields explicitly.)
What do you think of a capability flag `conditional` passed to the preconditioner init, which would specify whether or not the preconditioner (and underlying model) are expected to use the `condition` arg?
@pzharrington if it were only for the preconditioner, I would say okay. But we would need this `conditional` flag everywhere the model is passed as a callback (loss, samplers, etc.). That would make things very heavy IMO. Also, the loss object is supposed to be purely functional, so the flag would need to be passed to `__call__` and not `__init__`.
I prefer the second alternative that you proposed. Another option I just thought of would be to make the `condition` argument keyword-only, with a signature `model(x, t, *, condition=None, **model_kwargs)`. That removes some possible ambiguity and mix-up, because by forcing `condition` to always be passed by name, we forbid calls such as `model(x, t, condition={...}, kwargs1=..., condition={...}, other_kwarg=...)`, which are anyway invalid.
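To illustrate the keyword-only option, here is a minimal sketch; the wrapper name and body are hypothetical and not from this PR, only the bare `*` marker in the signature matters.

```python
from typing import Dict, Optional

import torch


class KeywordOnlyPreconditioner(torch.nn.Module):
    """Hypothetical wrapper showing the keyword-only `condition` signature."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        *,  # everything after this marker must be passed by keyword
        condition: Optional[Dict[str, torch.Tensor]] = None,
        **model_kwargs,
    ) -> torch.Tensor:
        # `condition` can never be swallowed positionally or confused with an
        # entry of **model_kwargs, because it must be spelled out by name.
        if condition is not None:
            return self.model(x, t, condition=condition, **model_kwargs)
        return self.model(x, t, **model_kwargs)


# precond(x, t, {"y": y})                           -> TypeError (condition is keyword-only)
# precond(x, t, condition={"y": y}, some_kwarg=1.0) -> OK
```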
@NickGeneva's comment above reminded me of another reason why I wanted an explicit argument for `condition`. In multi-diffusion (whose implementation is still at the stage of philosophical reflection as of now), the model needs to know which argument is `condition`, because it needs to apply specific operations to the conditioning tensors (patching, sometimes interpolation). So, in multi-diffusion, `condition` cannot be treated the same as any other `model_kwargs`.
Roughly, it should look like:
multi_diffusion_model = MultiDiffusionModel(model, patching_options)
x0 = multi_diffusion_model(x, t, condition={"y": y, "z": z}, some_kwarg=some_val)
Under the hood, the forward pass of `multi_diffusion_model` applies patching (and optionally interpolation) and concatenation to the items in `condition`, but it leaves all other kwargs untouched.
In multi-diffusion, the model needs to know which argument is condition
And it would, right? The underlying model would define `condition` in its forward kwargs and handle it accordingly within its forward pass. Under my second suggestion, the preconditioner wrapping it would simply pass `condition` through as part of `**model_kwargs`; it doesn't need to know about or explicitly do anything with the condition (there is no condition-dependent preconditioning, and if there were, that would be crazy 😅).
E.g., in your snippet, we'd have
multi_diffusion_model = MultiDiffusionModel(model, patching_options) # Wrap a base backbone with multidiffusion
model_precond = EDMPreconditioner(multi_diffusion_model) # Wrapped by preconditioner, ready for training
x0 = model_precond(x, t, condition={"y": y, "z": z}, some_kwarg=some_val)
The preconditioner doesn't care about the conditioning or the patching/interpolation applied to it; it delegates that to the underlying models. This is fine, no? Or am I missing something?
Hmm, I'm trying to understand that...
In your snippet above, what would be the signatures of `model`, `model_precond`, and `multi_diffusion_model`? (And I mean the actual signatures, not just the way they are called, because `model(x, t, condition={}, kwarg=val)` may look the same whether `condition` is a separate keyword argument or part of the kwargs, but the signature signals intent.)
PhysicsNeMo Pull Request
Overview
This PR refactors the diffusion model preconditioners to introduce a clean, extensible architecture. The goal is to standardize how preconditioners wrap neural network models and apply the preconditioning formula, making it easier to implement new preconditioning schemes and maintain existing ones.
- New `BasePreconditioner` abstract base class in `physicsnemo/diffusion/preconditioners/preconditioners.py` that defines a standardized interface for wrapping diffusion models with preconditioning. The base class handles the common forward pass logic while subclasses implement `compute_coefficients()` to define their specific preconditioning scheme. Subclasses can optionally override `sigma()` to implement custom noise schedules (time-to-noise mappings). A key improvement of this refactor is to enable dependency-injection design patterns by wrapping an arbitrary `physicsnemo.Module` instance with a preconditioner (a minimal subclassing sketch is included after this list).
- Four preconditioner reimplementations based on the new `BasePreconditioner`: `VPPreconditioner`, `VEPreconditioner`, `IDDPMPreconditioner`, and `EDMPreconditioner`. These are cleaner, standalone versions of the existing legacy preconditioners, with comprehensive docstrings including mathematical formulas for the preconditioning coefficients and noise schedules.
- Migrated legacy preconditioners in `legacy.py` to inherit from the new base classes. This eliminates code duplication while maintaining full backward compatibility: all existing method signatures, attributes, and behaviors are preserved. Users of the legacy API do not need to change their code.
- Comprehensive CI tests in `test/diffusion/test_preconditioners.py` covering constructors, the `sigma()` and `compute_coefficients()` methods, the forward pass, and checkpoint loading with `physicsnemo.Module.from_checkpoint()`.

Closes 🚀[FEA]: Allow passing models as to EDMPrecond #796.
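As a rough illustration of the intended subclassing pattern, here is a hedged sketch. The class below is a stand-in, not the PR's actual implementation, and the abstract-method signatures may differ from what is merged; the EDM coefficient formulas themselves follow Karras et al. (2022).

```python
import torch


class SketchEDMPreconditioner:
    """Stand-in, not the PR's class: shows the sigma() / compute_coefficients() split."""

    def __init__(self, model, sigma_data: float = 0.5):
        self.model = model          # any callable backbone, e.g. a physicsnemo.Module
        self.sigma_data = sigma_data

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        # EDM identifies the time variable with the noise level directly.
        return t

    def compute_coefficients(self, sigma: torch.Tensor):
        # EDM preconditioning coefficients (Karras et al., 2022).
        s2, d2 = sigma**2, self.sigma_data**2
        c_skip = d2 / (s2 + d2)
        c_out = sigma * self.sigma_data / torch.sqrt(s2 + d2)
        c_in = 1.0 / torch.sqrt(s2 + d2)
        c_noise = torch.log(sigma) / 4.0
        return c_skip, c_out, c_in, c_noise

    def __call__(self, x, t, condition, **model_kwargs):
        # Common forward logic the base class is said to factor out:
        #   D(x) = c_skip * x + c_out * F(c_in * x, c_noise, condition, ...)
        sigma = self.sigma(t).reshape(-1, *([1] * (x.ndim - 1)))  # broadcast over (B, *)
        c_skip, c_out, c_in, c_noise = self.compute_coefficients(sigma)
        F_x = self.model(c_in * x, c_noise.flatten(), condition=condition, **model_kwargs)
        return c_skip * x + c_out * F_x
```

In the actual API, the dependency-injection pattern described above amounts to wrapping an existing `physicsnemo.Module` instance with one of the provided preconditioners (e.g. `EDMPreconditioner(my_module)`) and training or sampling against the wrapped callable.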
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.