Add caching for transformer inference #154

Open
Adamits opened this issue Nov 16, 2023 · 3 comments
Labels: enhancement (New feature or request)

Comments
Adamits (Collaborator) commented Nov 16, 2023

Transformer inference (i.e., with no teacher forcing) is slow. In practice, I think people typically implement some kind of caching so that at each timestep we do not need to recompute the embeddings and attentions over all previously decoded timesteps.

I have a quick-and-dirty implementation of this in an experimental fork, where I basically tell the decoder layer to compute attention only for the most recently decoded target, with the representations of all earlier positions concatenated on. There are probably other tricks I could find by, e.g., inspecting some Hugging Face transformers inference code.

I propose adding an option to the transformer encoder-decoders to enable caching, wherein a CacheTransformerDecoder module is used.
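For illustration, here is a minimal PyTorch sketch of the idea, not the code in the fork; the CachedSelfAttention class and its signature are hypothetical, and it caches decoder-layer inputs rather than projected keys/values:

```python
# Hypothetical sketch: cache the decoder-layer inputs for already-decoded
# positions so that, at each step, only the newest target position is used
# as the attention query and earlier positions are simply concatenated on.
import torch
from torch import nn


class CachedSelfAttention(nn.Module):
    """Self-attention in which only the newest position queries the cache."""

    def __init__(self, d_model: int, nhead: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, new_step, cache=None):
        # new_step: B x 1 x d_model, embedding of the most recently decoded target.
        # cache: B x T x d_model, representations of all earlier decoded positions.
        keys_values = new_step if cache is None else torch.cat([cache, new_step], dim=1)
        # No causal mask is needed: the cache only ever contains earlier positions.
        out, _ = self.attn(new_step, keys_values, keys_values, need_weights=False)
        # Return the updated cache so the caller can reuse it at the next step.
        return out, keys_values


# Example decoding loop over 5 steps with a batch of 2.
layer = CachedSelfAttention(d_model=64, nhead=4)
cache = None
for step_embedding in torch.randn(5, 2, 1, 64):
    out, cache = layer(step_embedding, cache)
```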

This is a low-priority TODO, since we validate with accuracy rather than loss, and accuracy can be estimated reliably with teacher forcing when the targets are provided.

Adamits self-assigned this Nov 16, 2023
Adamits added the enhancement (New feature or request) label Nov 16, 2023
kylebgorman (Contributor) commented

So this is an approximation/hack, right? I'm fine with it, and maybe we could treat it as a separate architecture to keep things simple.

Adamits (Collaborator, Author) commented Nov 16, 2023

> So this is an approximation/hack, right?

I don't think so: for example, the attention of timestep t with respect to t-1 will always be the same, so caching it is just a memory vs. runtime tradeoff. The default is to pass in the full sequence of decoded tokens at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.
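As a rough sanity check (a standalone sketch using torch.nn.MultiheadAttention, not the fork's code), the newest position's output from a full causal recompute matches the output of a single cached-query step:

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
prefix = torch.randn(1, 4, 8)  # four decoded positions so far

# Default approach: recompute causal self-attention over the whole prefix.
mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
full, _ = attn(prefix, prefix, prefix, attn_mask=mask, need_weights=False)

# Cached approach: only the newest position queries the full (cached) prefix.
step, _ = attn(prefix[:, -1:], prefix, prefix, need_weights=False)

# The newest position's output is identical, so caching is exact here.
print(torch.allclose(full[:, -1:], step, atol=1e-6))  # True
```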

> treat it as a separate architecture to keep things simple.

Sure, we can do it that way.

kylebgorman (Contributor) commented

> So this is an approximation/hack, right?
>
> I don't think so: for example, the attention of timestep t with respect to t-1 will always be the same, so caching it is just a memory vs. runtime tradeoff. The default is to pass in the full sequence of decoded tokens at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.

Thanks for the clarification. All the better, then.
