From 33fd8d73b6c1b762ad7242c0aebffdce9bba8983 Mon Sep 17 00:00:00 2001 From: Graeme Nail Date: Wed, 24 Aug 2022 14:33:13 +0100 Subject: [PATCH] Check shape on transformer cache --- src/models/transformer.h | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/models/transformer.h b/src/models/transformer.h index 2d9ced33c..a149ae2c4 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -28,7 +28,7 @@ class Transformer : public EncoderOrDecoderBase { protected: using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_; - std::unordered_map cache_; // caching transformation of the encoder that should not be created again + std::unordered_map> cache_; // caching transformation of the encoder that should not be created again mutable/*lazy*/ std::vector sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings // attention weights produced by step() @@ -279,10 +279,10 @@ class Transformer : public EncoderOrDecoderBase { // Caching transformation of the encoder that should not be created again. // @TODO: set this automatically by memoizing encoder context and // memoization propagation (short-term) - if (cache // if caching - && cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen - && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change - kh = cache_[prefix + "_keys"]; // then return cached tensor + if (cache // if caching + && cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen + && cache_[prefix + "_keys"].first == keys->shape()) { // and the underlying element size did not change + kh = cache_[prefix + "_keys"].second; // then return cached tensor } else { auto Wk = graph_->param(prefix + "_Wk", {dimModel, dimModel}, inits::glorotUniform()); @@ -290,21 +290,21 @@ class Transformer : public EncoderOrDecoderBase { kh = affine(keys, Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim] kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim] - cache_[prefix + "_keys"] = kh; + cache_[prefix + "_keys"] = std::make_pair(keys->shape(), kh); } Expr vh; - if (cache - && cache_.count(prefix + "_values") > 0 - && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) { - vh = cache_[prefix + "_values"]; + if (cache + && cache_.count(prefix + "_values") > 0 + && cache_[prefix + "_values"].first == values->shape()) { + vh = cache_[prefix + "_values"].second; } else { auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorotUniform()); auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros()); vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim] vh = SplitHeads(vh, dimHeads); - cache_[prefix + "_values"] = vh; + cache_[prefix + "_values"] = std::make_pair(values->shape(), vh); } int dimBeam = q->shape()[-4];