|
| 1 | +#include "infinicore/ops/gla_attention.hpp" |
| 2 | + |
| 3 | +#include "infinicore/ops/matmul.hpp" |
| 4 | +#include "infinicore/ops/causal_softmax.hpp" |
| 5 | +#include "../../utils.hpp" |
| 6 | + |
| 7 | +namespace infinicore::op { |
| 8 | + |
| 9 | +Tensor gla_attention(const Tensor &q, |
| 10 | + const Tensor &k_total, |
| 11 | + const Tensor &v_total, |
| 12 | + float scale, |
| 13 | + bool causal) { |
| 14 | + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k_total, v_total); |
| 15 | + |
| 16 | + const auto &q_shape = q->shape(); // [B, n_q, S_q, D] |
| 17 | + const auto &k_shape = k_total->shape(); // [B, n_kv, S_kv, D] |
| 18 | + const auto &v_shape = v_total->shape(); // [B, n_kv, S_kv, D] |
| 19 | + |
| 20 | + INFINICORE_ASSERT(q_shape.size() == 4); |
| 21 | + INFINICORE_ASSERT(k_shape.size() == 4); |
| 22 | + INFINICORE_ASSERT(v_shape.size() == 4); |
| 23 | + INFINICORE_ASSERT(q_shape[0] == k_shape[0] && k_shape[0] == v_shape[0]); // B |
| 24 | + INFINICORE_ASSERT(q_shape[3] == k_shape[3] && k_shape[3] == v_shape[3]); // D |
| 25 | + INFINICORE_ASSERT(k_shape[1] == v_shape[1] && k_shape[2] == v_shape[2]); // n_kv, S_kv |
| 26 | + |
| 27 | + const size_t B = q_shape[0]; |
| 28 | + const size_t n_q = q_shape[1]; |
| 29 | + const size_t S_q = q_shape[2]; |
| 30 | + const size_t D = q_shape[3]; |
| 31 | + const size_t n_kv = k_shape[1]; |
| 32 | + const size_t S_kv = k_shape[2]; |
| 33 | + |
| 34 | + INFINICORE_ASSERT(n_q % n_kv == 0); |
| 35 | + const size_t ngroup = n_q / n_kv; |
| 36 | + |
| 37 | + // Reshape to grouped GQA layout: |
| 38 | + // Q: [B * n_kv, ngroup * S_q, D] |
| 39 | + // K: [B * n_kv, S_kv, D] |
| 40 | + // V: [B * n_kv, S_kv, D] |
| 41 | + auto Q = q->view({B * n_kv, ngroup, S_q, D}) |
| 42 | + ->view({B * n_kv, ngroup * S_q, D}); |
| 43 | + auto K = k_total->view({B * n_kv, S_kv, D}); |
| 44 | + auto V = v_total->view({B * n_kv, S_kv, D}); |
| 45 | + |
| 46 | + auto Kt = K->permute({0, 2, 1}); // [B * n_kv, D, S_kv] |
| 47 | + auto attn_weight = infinicore::op::matmul(Q, Kt, scale); // [B * n_kv, ngroup * S_q, S_kv] |
| 48 | + |
| 49 | + if (causal) { |
| 50 | + auto attn_weight_softmax = |
| 51 | + attn_weight->view({B * n_q, S_q, S_kv}); // [B * n_q, S_q, S_kv] |
| 52 | + infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax); |
| 53 | + } |
| 54 | + |
| 55 | + auto out = infinicore::op::matmul(attn_weight, V); // [B * n_kv, ngroup * S_q, D] |
| 56 | + auto out_view = |
| 57 | + out->view({B, n_kv, ngroup, S_q, D}) |
| 58 | + ->view({B, n_q, S_q, D}); // merge kv,group back into n_q |
| 59 | + |
| 60 | + return out_view; |
| 61 | +} |
| 62 | + |
| 63 | +} // namespace infinicore::op |
| 64 | + |
0 commit comments