mad #4983
Conversation
Caution: Review failed. The pull request is closed.

📝 Walkthrough

Adds MAD-related regularization to the DPA3 descriptor and energy loss (a new EnergyStdLossMAD), wires it into training and config parsing, and significantly refactors RepFlows, network MLPs, and graph-index utilities. Several modules receive extensive documentation/comments. An example reduces its training steps.
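Since the walkthrough says the new loss is wired into config parsing via `loss_type="ener_mad"` and a `mad_reg_coeff` parameter (both named in the review summary), a minimal sketch of such a loss-config fragment might look like the following. The remaining keys are typical DeePMD-style energy-loss preferences and are assumptions here, not verified against this PR:

```python
# Hypothetical loss section selecting the MAD-regularized energy loss.
# "ener_mad" and "mad_reg_coeff" come from the review summary; the
# start/limit prefactor keys are assumed for illustration only.
loss_config = {
    "type": "ener_mad",     # dispatched by training.get_loss to EnergyStdLossMAD
    "mad_reg_coeff": 0.01,  # weight of the |MAD - 1| regularization term
    "start_pref_e": 0.02,   # energy prefactor at start of training (assumed key)
    "limit_pref_e": 1.0,    # energy prefactor at end of training (assumed key)
}
```

Setting `mad_reg_coeff` to 0 should recover the base energy/force/virial loss, per the `mad_reg_coeff > 0` branch in the diagram below.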
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant Trainer as training.get_loss
    participant Loss as EnergyStdLossMAD
    participant Model as DPA3/RepFlows Model
    participant Desc as DPA3 Descriptor
    User->>Trainer: loss_type="ener_mad", loss_params
    Trainer->>Loss: construct(**loss_params)
    Note right of Loss: mad_reg_coeff stored
    loop Training Step
        Trainer->>Model: forward(input)
        Model->>Desc: forward(...)
        alt enable_mad
            Desc->>Desc: compute MAD (store last_mad_gap)
        else disable
            Desc->>Desc: last_mad_gap=None
        end
        Trainer->>Loss: forward(outputs, model, labels, natoms, lr)
        Loss->>Loss: base energy/force/virial losses
        alt mad_reg_coeff > 0 and last_mad_gap available
            Loss->>Desc: read last_mad_gap
            Loss->>Loss: mad_reg_loss = coeff * |MAD - 1|
            Loss->>Trainer: total_loss += mad_reg_loss
        else no MAD reg
            Loss-->>Trainer: total_loss (base)
        end
    end
```
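The `mad_reg_loss = coeff * |MAD - 1|` step above can be sketched as follows. MAD is read here as the mean average cosine distance between node embeddings, a common over-smoothing metric for graph networks; the actual descriptor stores a `last_mad_gap` whose exact pairing and masking are not visible in this summary, so the all-pairs averaging below is an assumption:

```python
import numpy as np

def mad_reg_loss(node_ebd: np.ndarray, coeff: float) -> float:
    """Sketch of coeff * |MAD - 1| over a (nnode, dim) embedding matrix.
    MAD = mean cosine distance over all distinct node pairs (assumed form).
    """
    # normalize rows so the dot product gives cosine similarity
    normed = node_ebd / np.linalg.norm(node_ebd, axis=-1, keepdims=True)
    cos = normed @ normed.T          # (nnode, nnode) cosine similarities
    dist = 1.0 - cos                 # cosine distances
    n = node_ebd.shape[0]
    mask = ~np.eye(n, dtype=bool)    # drop self-pairs before averaging
    mad = dist[mask].mean()
    return coeff * abs(mad - 1.0)
```

Note that fully collapsed (identical) embeddings give MAD = 0 and hence the maximal penalty `coeff`, which matches the intent of regularizing against over-smoothing.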
```mermaid
sequenceDiagram
    autonumber
    participant Model as DescrptBlockRepflows.forward
    participant Env as prod_env_mat
    participant Utils as get_graph_index
    participant Layer as RepFlowLayer[xN]
    participant Border as torch.ops.deepmd.border_op
    Model->>Env: build edge env (dmatrix, diff, sw)
    Model->>Env: build angle env (a_dist_mask, a_diff, a_sw)
    Model->>Utils: get_graph_index(nlist, masks, a_masks, nall, use_loc_mapping)
    alt parallel_mode
        Model->>Border: exchange/gather features
    end
    loop layers
        Model->>Layer: forward(node_ebd_ext, edge_ebd, h2, angle_ebd, indices, masks, sw)
        Layer-->>Model: updated embeddings
    end
    Model->>Layer: _cal_hg / _cal_hg_dynamic
    Model-->>Caller: node_ebd, edge_ebd, h2, rot_mat, sw
```
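The control flow of this diagram can be sketched with plain callables standing in for `RepFlowLayer` and `torch.ops.deepmd.border_op`. The real `DescrptBlockRepflows.forward` first builds the edge/angle environments and graph indices and carries many more tensors; the signature and return values below are simplified assumptions:

```python
def repflows_forward(layers, node_ebd, edge_ebd, h2, angle_ebd,
                     indices, masks, sw, parallel_mode=False, border_op=None):
    """Control-flow sketch of DescrptBlockRepflows.forward (simplified).
    Each `layer` is assumed to return updated (node, edge, angle) embeddings.
    """
    if parallel_mode and border_op is not None:
        # in the real code, torch.ops.deepmd.border_op exchanges/gathers
        # halo features across ranks before the layer loop
        node_ebd = border_op(node_ebd)
    for layer in layers:
        node_ebd, edge_ebd, angle_ebd = layer(
            node_ebd, edge_ebd, h2, angle_ebd, indices, masks, sw
        )
    # the real forward additionally computes rot_mat via _cal_hg and
    # returns (node_ebd, edge_ebd, h2, rot_mat, sw)
    return node_ebd, edge_ebd, h2
```

With identity-like stand-in layers, the loop simply composes the per-layer updates in order, which is all the diagram's `loop layers` block expresses.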
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Summary by CodeRabbit
- New Features
- Bug Fixes
- Documentation
- Chores