# [FEATURE] Gradient checkpointing in `forward_intermediates()` (#2435)
## Comments
**collinmccarthy:** I just noticed that for ConvNeXt the gradient checkpointing is done within a ConvNeXt stage, which means it would work as-is for `forward_intermediates()`.
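(For reference, the stage-level gating being described looks roughly like the following. This is a reconstruction of timm's `ConvNeXtStage.forward()`, so details may differ from the current source.)

```python
# Reconstruction of ConvNeXtStage.forward(): the checkpointing branch wraps
# the blocks of a single stage, so any caller that invokes the stage as a
# whole -- including forward_intermediates() -- gets checkpointing for free.
def forward(self, x):
    x = self.downsample(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    return x
```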
**collinmccarthy:** Also, shouldn't this be called activation checkpointing, not gradient checkpointing? Just want to make sure I'm not misunderstanding the implementation / goal here. I'm guessing the name comes from the HuggingFace trainer flag, but is a bit of a misnomer?
**Maintainer:** @collinmccarthy you are correct on all counts. I didn't explicitly support this when I added `forward_intermediates()`, as I was focused on getting it working and integrated, and then didn't revisit. Stage-based models that needed to push the logic into the stages should still work. Activation checkpointing makes more sense as the name / description of what's going on, but historically it was often called gradient checkpointing, so the term persisted; not going to change that now. If you've tried the above additions and they work, a PR would be welcome for any models you happen to be working with. It should use my checkpoint wrapper around the torch one (it changes the reentrant arg).
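(The wrapper being referred to presumably pins the reentrant argument on torch's implementation; a minimal sketch of that idea, not necessarily timm's exact code:)

```python
import torch.utils.checkpoint


def checkpoint(function, *args, **kwargs):
    # Sketch: forward to torch's checkpoint but force the non-reentrant
    # implementation, which PyTorch recommends over the legacy reentrant one.
    return torch.utils.checkpoint.checkpoint(
        function, *args, use_reentrant=False, **kwargs)
```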
**collinmccarthy:** Thanks, all this sounds great. I'll submit a PR soon for just …
## Original issue

**Is your feature request related to a problem? Please describe.**
I rely on the `forward_intermediates()` API for object detection models, and I'm experimenting with ViT-g and would like to try gradient checkpointing.

**Describe the solution you'd like**
In `VisionTransformer.forward_features()` we have:
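(The code block did not survive this copy of the issue; the branch in question, reconstructed from timm's `VisionTransformer.forward_features()`, is roughly:)

```python
# Excerpt from VisionTransformer.forward_features() (reconstruction);
# checkpoint_seq comes from timm.models._manipulate.
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint_seq(self.blocks, x)
else:
    x = self.blocks(x)
```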
I'm thinking something like this could work in `VisionTransformer.forward_intermediates()`:
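(The proposed snippet was also lost in this copy; the sketch below follows the existing block loop in `forward_intermediates()` — `blocks`, `take_indices`, `norm`, and `intermediates` are the names that loop already uses, and the `checkpoint()` call stands in for the hypothetical `checkpoint_module()`:)

```python
# Sketch: checkpoint each block inside the existing intermediates loop,
# mirroring the grad_checkpointing gate used in forward_features().
for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint(blk, x)  # hypothetical checkpoint_module() / wrapper
    else:
        x = blk(x)
    if i in take_indices:
        # normalize intermediates with the final norm layer if requested
        intermediates.append(self.norm(x) if norm else x)
```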
I called this `checkpoint_module()`, but I think we could just use `checkpoint_seq()` directly, based on the code? Either way, is this as simple as I think it would be, or am I missing something? I haven't used gradient checkpointing a lot, so I'm not entirely sure.

I'm happy to submit a PR for a few models if it's as simple as calling `checkpoint_seq()` in `forward_intermediates()` as I've outlined above. I'm not sure how many models use this API and/or `self.grad_checkpointing`, and whether you want this to be supported in all of them.
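(To round this out, a hedged sketch of how the pieces would combine once supported; the model name and `indices` value are illustrative only:)

```python
import timm
import torch

# Illustrative model choice; any ViT exposing forward_intermediates() works.
model = timm.create_model('vit_giant_patch14_224', pretrained=False)
model.set_grad_checkpointing(True)  # flips model.grad_checkpointing to True

x = torch.randn(1, 3, 224, 224)
# Returns (final_features, list_of_intermediates); with the proposed change,
# each block would be activation-checkpointed during training.
final, intermediates = model.forward_intermediates(x, indices=4)
```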