Transformer inference (i.e., without teacher forcing) is slow. In practice, I think people typically implement some form of caching so that, at each timestep, we do not need to recompute the embeddings and attentions over all previously decoded timesteps.
I have a quick-and-dirty implementation of this in an experimental fork, where I basically tell the decoder layer to compute attention only from the most recently decoded target, concatenating on the representations of all earlier steps. There are probably other tricks I could find by, e.g., inspecting some Hugging Face transformers inference code.
I propose adding an option to the transformer encoder-decoders to use caching, wherein a CacheTransformerDecoder module is used.
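To make the idea concrete, here is a rough sketch of what a cached decoding step might look like. This is only an illustration under my own assumptions, not the fork's actual code; the `CachedSelfAttention` class and all names in it are hypothetical.

```python
import math
import torch
import torch.nn as nn


class CachedSelfAttention(nn.Module):
    """Single-head causal self-attention with a key/value cache (illustrative only)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(self, x_new: torch.Tensor, cache: dict):
        # x_new: (batch, 1, d_model), the embedding of the most recently
        # decoded target; cache holds keys/values for all earlier steps.
        k_new, v_new = self.k(x_new), self.v(x_new)
        k = torch.cat([cache["k"], k_new], dim=1) if cache else k_new
        v = torch.cat([cache["v"], v_new], dim=1) if cache else v_new
        cache.update(k=k, v=v)
        # A single query row attends over the cached keys; no causal mask is
        # needed because later positions have not been decoded yet.
        weights = torch.softmax(
            self.q(x_new) @ k.transpose(1, 2) * self.scale, dim=-1
        )
        return weights @ v, cache


attention = CachedSelfAttention(d_model=16)
cache = {}
out_1, cache = attention(torch.randn(2, 1, 16), cache)  # step 1
out_2, cache = attention(torch.randn(2, 1, 16), cache)  # step 2 reuses step 1's k/v
```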
This is a low-priority TODO since we do validation with accuracy, rather than loss, and accuracy can be reliably predicted with teacher forcing if the targets are provided.
Maybe treat it as a separate architecture to keep things simple?

I don't think so. The attention of t with respect to t-1, for example, will always be the same, so caching is just a memory vs. runtime tradeoff. The default is to pass the full sequence of decoded tokens in at every timestep and recompute all of the attentions each time. I will think that through more and test before opening a PR, though.
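For what it's worth, the claim that already-decoded positions are unaffected by later steps under causal masking can be sanity-checked with stock torch.nn modules. This is a quick illustration with made-up shapes, not this codebase's decoder:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerDecoderLayer(d_model=16, nhead=4, batch_first=True).eval()

memory = torch.randn(1, 5, 16)   # stand-in for encoder output
tgt_3 = torch.randn(1, 3, 16)    # embeddings of the first 3 decoded targets
tgt_4 = torch.cat([tgt_3, torch.randn(1, 1, 16)], dim=1)  # append a 4th step


def causal_mask(n: int) -> torch.Tensor:
    # Standard additive causal mask: -inf above the diagonal, 0 elsewhere.
    return torch.triu(torch.full((n, n), float("-inf")), diagonal=1)


with torch.no_grad():
    out_3 = layer(tgt_3, memory, tgt_mask=causal_mask(3))
    out_4 = layer(tgt_4, memory, tgt_mask=causal_mask(4))

# The first three rows agree, so recomputing them at every timestep is
# redundant work that a cache can skip.
print(torch.allclose(out_3, out_4[:, :3], atol=1e-5))
```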