Update transolver to comply with model standards #1316
Conversation
Greptile Overview
Greptile Summary
This PR updates the Transolver model to comply with PhysicsNeMo model implementation standards by adding comprehensive documentation, type annotations, and validation logic.
Major changes:
- Added complete NumPy-style docstrings with proper sections (Parameters, Forward, Outputs, Examples) across all model classes and functions
- Added jaxtyping type annotations for all tensor arguments following MOD-006
- Added input validation with `torch.compiler.is_compiling()` guards in main forward methods following MOD-005
- Added high-level comments explaining complex tensor operations following MOD-003k
- Updated `pyproject.toml` to ignore F722 (allows jaxtyping syntax)
- Changed docstring prefixes from `"""` to `r"""` for LaTeX compatibility following MOD-003b
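The `r"""` change matters because backslash sequences in LaTeX math notation are otherwise interpreted as Python string escapes. A minimal illustration (plain Python, not code from the PR):

```python
# In a normal string, "\times" begins with "\t", which Python reads as a TAB
# character, silently corrupting the rendered docstring. A raw string keeps
# the backslash intact for Sphinx/LaTeX.
plain = "\times d"   # 7 characters: TAB + "imes d"
raw = r"\times d"    # 8 characters: backslash + "times d"

assert len(plain) == 7 and plain[0] == "\t"   # escape was applied
assert len(raw) == 8 and raw[0] == "\\"       # backslash preserved for LaTeX
```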
Critical issues found:
- `MLP` and `Transolver_block` classes inherit from `nn.Module` instead of `physicsnemo.Module`, violating MOD-001. These classes should be updated to inherit from `physicsnemo.Module` to ensure access to serialization, versioning, and registry features.
- Missing input validation in `MLP.forward()` and `Transolver_block.forward()` methods (MOD-005 requirement)
Positive aspects:
- Excellent documentation quality with clear examples and cross-references
- Proper use of LaTeX math notation for tensor shapes
- Good high-level comments in complex tensor operations
- Consistent formatting and structure across all files
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| physicsnemo/models/transolver/transolver.py | 3/5 | Added comprehensive docstrings with jaxtyping annotations, input validation, and high-level comments. Found critical issue: MLP and Transolver_block inherit from nn.Module instead of physicsnemo.Module (violates MOD-001). |
| physicsnemo/models/transolver/Physics_Attention.py | 4/5 | Added comprehensive docstrings with proper sections, jaxtyping annotations, and input validation. All classes correctly inherit from nn.Module (appropriate for reusable layers per MOD-000a). |
| physicsnemo/models/transolver/Embedding.py | 4/5 | Added complete docstrings with proper NumPy-style sections, jaxtyping annotations, LaTeX math notation for tensor shapes, and Examples sections. All classes correctly inherit from nn.Module. |
| pyproject.toml | 5/5 | Added F722 to ruff ignore list (allows jaxtyping syntax) and removed trailing whitespace. Changes are appropriate for supporting jaxtyping annotations. |
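For context on the F722 change: jaxtyping encodes tensor shapes as string literals inside annotations, which Ruff's F722 rule ("syntax error in forward annotation") flags as invalid forward references. A hedged sketch of the pattern (using PEP 563 postponed evaluation so it runs even without jaxtyping installed; `scale` is an illustrative function, not code from the PR):

```python
from __future__ import annotations  # annotations stored as strings, never evaluated

import torch

# F722 tries to parse the string "batch seq dim" as a Python forward
# reference and fails, which is why the PR adds F722 to the ruff ignore
# list in pyproject.toml.
def scale(x: Float[torch.Tensor, "batch seq dim"]) -> Float[torch.Tensor, "batch seq dim"]:
    # Runtime behavior is unaffected by the annotation.
    return 2.0 * x

y = scale(torch.ones(1, 2, 3))
assert y.shape == (1, 2, 3)
```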
The flagged line in `transolver.py`:

```python
class MLP(nn.Module):
```
Violates MOD-001: `MLP` inherits from `nn.Module` but should inherit from `physicsnemo.Module`.
Per MOD-001, all model classes must inherit from `physicsnemo.Module` rather than `torch.nn.Module`; this ensures access to PhysicsNeMo features such as serialization, versioning, and registry support.
Since `MLP` is a reusable layer and part of the Transolver model hierarchy, it should follow this standard.
The flagged line in `transolver.py`:

```python
class Transolver_block(nn.Module):
```
Violates MOD-001: `Transolver_block` inherits from `nn.Module` but should inherit from `physicsnemo.Module`.
Per MOD-001, all model classes must inherit from `physicsnemo.Module` rather than `torch.nn.Module`; this ensures access to PhysicsNeMo features such as serialization, versioning, and registry support.
```python
def forward(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_out"]:
```
Missing validation for `MLP.forward()`: per MOD-005, forward methods must validate tensor shapes at the beginning.
Add validation wrapped in an `if not torch.compiler.is_compiling():` check to verify that the input tensor has the expected shape.
Suggested change:

```python
def forward(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_out"]:
    r"""
    Forward pass of the MLP.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape :math:`(*, D_{in})`.

    Returns
    -------
    torch.Tensor
        Output tensor of shape :math:`(*, D_{out})`.
    """
    ### Input validation
    if not torch.compiler.is_compiling():
        if x.shape[-1] != self.n_input:
            raise ValueError(
                f"Expected input with {self.n_input} features, "
                f"got {x.shape[-1]} features (shape: {tuple(x.shape)})"
            )
    # Project input to hidden dimension
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
```python
def forward(
    self, fx: Float[torch.Tensor, "batch tokens hidden"]
) -> Float[torch.Tensor, "batch tokens out"]:
```
Missing validation for `Transolver_block.forward()`: per MOD-005, forward methods must validate tensor shapes.
Add validation wrapped in `if not torch.compiler.is_compiling():` to check that the input shape matches the expected `(B, N, hidden_dim)`.
```python
def forward(
    self,
    coordinates: Float[torch.Tensor, "batch seq"],
    device: torch.device,
) -> Float[torch.Tensor, "batch seq dim"]:
```
Missing validation for `RotaryEmbedding.forward()`: per MOD-005, forward methods should validate tensor shapes.
Consider adding validation to check that `coordinates` has the expected 2D shape `(B, N)`.
```python
def forward(
    self, x: Float[torch.Tensor, "batch seq dim"]
) -> Float[torch.Tensor, "batch seq dim"]:
```
Missing validation for `PositionalEncoding.forward()`: per MOD-005, forward methods should validate tensor shapes.
Consider adding validation to check that the input is 3D `(B, N, D)` and that `D` matches `d_model`.
PhysicsNeMo Pull Request
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.