
[FEATURE] Gradient checkpointing in forward_intermediates() #2435

Open

collinmccarthy opened this issue Feb 5, 2025 · 4 comments

Labels
enhancement New feature or request

Comments

@collinmccarthy
Contributor

Is your feature request related to a problem? Please describe.
I rely on the forward_intermediates() API for object detection models. I'm experimenting with ViT-g and would like to try gradient checkpointing.

Describe the solution you'd like
In VisionTransformer.forward_features() we have:

if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint_seq(self.blocks, x)

I'm thinking something like this could work in VisionTransformer.forward_intermediates():

for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_module(blk, x)
    else:
        x = blk(x)

I called this checkpoint_module(), but based on the code I think we could just use checkpoint_seq() directly? Either way, is this as simple as I think it would be, or am I missing something? I haven't used gradient checkpointing much, so I'm not entirely sure.
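For illustration, the loop might look like this if it used checkpoint_seq() directly (assuming checkpoint_seq() accepts a plain list of callables, which it appears to from _manipulate.py; take_indices / intermediates here just stand in for whatever bookkeeping forward_intermediates() already does):

for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        # checkpoint a single block by passing a one-element list
        x = checkpoint_seq([blk], x)
    else:
        x = blk(x)
    if i in take_indices:
        intermediates.append(x)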

I'm happy to submit a PR for a few models if it's as simple as calling checkpoint_seq() in forward_intermediates() as I've outlined above. I'm not sure how many models use this API and/or self.grad_checkpointing, and whether you want this to be supported in all of them.

collinmccarthy added the enhancement (New feature or request) label on Feb 5, 2025
@collinmccarthy
Contributor Author

I just noticed that for ConvNeXt, gradient checkpointing is done within a ConvNeXt stage, which means it works as-is with forward_intermediates(). So maybe this feature request is specific to VisionTransformer (and other models whose checkpointing doesn't already happen inside the stages).
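Roughly, the pattern there is something like this (a paraphrase of the stage-level forward, not the exact timm code):

# Rough paraphrase of a stage-level forward (e.g. ConvNeXt); not the exact timm code.
def forward(self, x):
    x = self.downsample(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        # Checkpointing happens inside the stage, so forward_intermediates(),
        # which iterates over stages, still benefits from it unchanged.
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    return x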

@collinmccarthy
Contributor Author

collinmccarthy commented Feb 6, 2025

Also, shouldn't this be called activation checkpointing rather than gradient checkpointing? I just want to make sure I'm not misunderstanding the implementation / goal here. I'm guessing the name comes from the HuggingFace trainer flag, but it seems like a bit of a misnomer?

@rwightman
Collaborator

@collinmccarthy you are correct on all counts. I didn't explicitly support this when I added forward_intermediates() as I was focused on getting it working / integrated, and then didn't revisit.

Stage-based models that push the checkpointing logic into the stages should still work.

Activation checkpointing makes more sense as the name / description of what's going on, but historically it was often called gradient checkpointing, so the name persisted. Not going to change that now.

If you've tried the above additions and they work, a PR would be welcome for any models you happen to be working with.

It should use my checkpoint wrapper around the torch one (it changes the reentrant arg):

from ._manipulate import checkpoint

...

def forward_intermediates(self, x, ...):
    ...
    for blk in self.blocks:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(blk, x)
        else:
            x = blk(x)
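(For reference, that wrapper is essentially a thin pass-through to torch.utils.checkpoint.checkpoint that adjusts the use_reentrant argument; the sketch below is illustrative rather than the exact _manipulate.py code.)

# Illustrative sketch only; the real helper lives in timm/models/_manipulate.py
# and may differ in details.
import torch.utils.checkpoint


def checkpoint(function, *args, use_reentrant=False, **kwargs):
    # Forward to the torch implementation, defaulting to the non-reentrant
    # variant recommended by recent PyTorch releases.
    return torch.utils.checkpoint.checkpoint(
        function, *args, use_reentrant=use_reentrant, **kwargs)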

@collinmccarthy
Contributor Author

Thanks, this all sounds great. I'll submit a PR soon for just VisionTransformer for now, and if I run across other models I need in the future, I'll submit PRs for those and reference this issue.
