-
Notifications
You must be signed in to change notification settings - Fork 559
Diffusion preconditioners refactor #1317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Diffusion preconditioners refactor #1317
Conversation
No tests fixed yet.
phsyicsnemo.utils, launch.config is just gone. It was empty.
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
…d of python float Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
|
/blossom-ci |
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
|
/blossom-ci |
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
|
/blossom-ci |
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
|
/blossom-ci |
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
|
/blossom-ci |
…rotocol Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
…tioner Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
| t: Float[torch.Tensor, "B"], # noqa: F821 | ||
| condition: Float[torch.Tensor, "B *cond_dims"] | TensorDict | None = None, # noqa: F821 | ||
| **model_kwargs: Any, | ||
| ) -> Float[torch.Tensor, "B *dims"]: # noqa: F821 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: these # noqa: F821 can be dropped, see https://nvidia.slack.com/archives/C09QAS52AKV/p1768397257037909?thread_ts=1768355832.954499&cid=C09QAS52AKV
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For some reason, I still have linter errors if I remove them. For example Float[torch.Tensor, "B *dims" gives me Undefined name 'B'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try the space in front of B, like Float[torch.Tensor, " B *dims"]. I think that's the fix jaxtyping recommends, though a bit ugly
pzharrington
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending fix for the noqa flags!
PhysicsNeMo Pull Request
Overview
This PR refactors the diffusion model preconditioners to introduce a clean, extensible architecture. The goal is to standardize how preconditioners wrap neural network models and apply the preconditioning formula, making it easier to implement new preconditioning schemes and maintain existing ones.
New
BasePreconditionerabstract base class inphysicsnemo/diffusion/preconditioners/preconditioners.pythat defines a standardized interface for wrapping diffusion models preconditioning. The base class handles the common forward pass logic while subclasses implementcompute_coefficients()to define their specific preconditioning scheme. Subclasses can optionally overridesigma()to implement custom noise schedules (time-to-noise mappings). A key improvement of this refactor is to enable dependency-injection design patterns by wrapping arbitraryphysicsnemo.Moduleinstance with a preconditioner.Four preconditioner reimplementations based on the new
BasePreconditioner:VPPreconditioner,VEPreconditioner,IDDPMPreconditioner, andEDMPreconditioner. These are cleaner, standalone versions of the existing legacy preconditioners with comprehensive docstrings including mathematical formulas for the preconditioning coefficients and noise schedules.Migrated legacy preconditioners in
legacy.pyto inherit from the new base classes. This eliminates code duplication while maintaining full backward compatibility—all existing method signatures, attributes, and behaviors are preserved. Users of the legacy API do not need to change their code.Comprehensive CI tests in
test/diffusion/test_preconditioners.pycovering:sigma()andcompute_coefficients()methodsphysicsnemo.Module.from_checkpoint()Closes 🚀[FEA]: Allow passing models as to EDMPrecond #796 .
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.