diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index 40fa1770bbb..1615cdbb857 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -5,7 +5,6 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
-#include
 #include
 #include
 #include
@@ -269,14 +268,15 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const {
   NVF_CHECK(false, "Tensor op can not be printed inline");
 }

-namespace sdpa_meta {
-
+namespace {
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-_scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
-  const int batch_size = query.size(0);
-  const int num_heads = query.size(1);
-  const int seqlen_q = query.size(2);
-  const int out_head_dim = value.size(-1);
+scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
+  NVF_ERROR_EQ(query.dim(), 4);
+  NVF_ERROR_EQ(value.dim(), 4);
+  const auto batch_size = query.size(0);
+  const auto num_heads = query.size(1);
+  const auto seqlen_q = query.size(2);
+  const auto out_head_dim = value.size(-1);
   auto out = at::empty(
       {batch_size, num_heads, seqlen_q, out_head_dim}, query.options());

@@ -294,10 +294,6 @@ _scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
   return std::make_tuple(out, logsumexp, rng_state, rng_offset);
 }

-} // namespace sdpa_meta
-
-namespace {
-
 at::Tensor flattenBatchDims(at::Tensor t) {
   at::DimVector new_shape({-1});
   auto non_batch_dims = t.sizes().slice(t.dim() - 3);
@@ -311,6 +307,33 @@ at::Tensor unflattenBatchDim(at::Tensor t, at::IntArrayRef batch_dims) {
   new_shape.append(non_batch_dims.begin(), non_batch_dims.end());
   return t.view(new_shape);
 }
+
+// at::_scaled_dot_product_attention_math and
+// scaled_dot_product_attention_meta produce a contiguous attention output,
+// but SdpaFwdOp requires the attention output to be in the same layout as
+// the query input:
+// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
+// Therefore, we relayout the attention output according to attn_out()'s
+// allocation domain.
+//
+// at::_scaled_dot_product_flash_attention goes through this code as well,
+// but the `.contiguous()` will be a no-op.
+at::Tensor relayoutByTensorView(at::Tensor t, TensorView* tv) {
+  const std::optional<Layout> layout = canonicalizeLayout(tv);
+  NVF_CHECK(layout.has_value(), "Failed to canonicalize output layout of ", tv);
+  const std::optional<std::vector<int64_t>> permutation =
+      ir_utils::computePermutation(
+          tv->getLogicalDomain(), layout->allocation_domain());
+  NVF_ERROR(
+      permutation.has_value(),
+      "The allocation domain of a canonicalized layout of ",
+      tv,
+      " is not a permutation of its logical domain.");
+  return t.permute(*permutation)
+      .contiguous()
+      .permute(ir_utils::inversePermutation(*permutation));
+}
+
 } // namespace

 std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
@@ -366,10 +389,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
     attn_bias = flattenBatchDims(attn_bias.contiguous());
   }

-  // 4D SDPA
+  // 4D SDPA. `output`'s strides don't necessarily match `attn_out()`'s
+  // requirements; this is fixed right after the lambda.
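+  // Relayouting once, outside the lambda, lets all three paths (meta, math
+  // fallback, and flash attention) share relayoutByTensorView; for the flash
+  // path the `.contiguous()` inside it is a no-op.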
   auto [output, logsumexp, philox_seed, philox_offset] = [&]() {
     if (query.is_meta()) {
-      return sdpa_meta::_scaled_dot_product_attention_meta(query, value);
+      return scaled_dot_product_attention_meta(query, value);
     }

     if (attn_bias.defined()) {
@@ -396,7 +420,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
         "dropout_p is 0.0.");
     auto philox_seed = at::empty({2}, query.options().dtype(at::kUInt64));
     auto philox_offset = at::empty({}, query.options().dtype(at::kUInt64));
-    auto [out, logsumexp] = at::_scaled_dot_product_attention_math(
+    auto [attn_out, logsumexp] = at::_scaled_dot_product_attention_math(
         query,
         key,
         value,
@@ -405,34 +429,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
         is_causal,
         /*dropout_mask=*/std::nullopt,
         scale);
-
-    // at::_scaled_dot_product_attention_math produces a contiguous attention
-    // output, but SdpaFwdOp requires the attention output to be in the same
-    // layout as the query input:
-    // https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
-    // Therefore, we relayout the attention output according to attn_out()'s
-    // allocation domain.
-    NVF_ERROR(out.is_contiguous());
-    const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
-    NVF_CHECK(
-        out_layout.has_value(),
-        "Failed to canonicalize output layout of ",
-        attn_out());
-    const std::optional<std::vector<int64_t>> permutation =
-        ir_utils::computePermutation(
-            attn_out()->getLogicalDomain(), out_layout->allocation_domain());
     NVF_ERROR(
-        permutation.has_value(),
-        "The allocation domain of a canonicalized layout of ",
-        attn_out(),
-        " is not a permutation of its logical domain.");
-    out = unflattenBatchDim(out, batch_dims);
-    out = out.permute(*permutation)
-              .contiguous()
-              .permute(ir_utils::inversePermutation(*permutation));
-    out = flattenBatchDims(out);
-
-    return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
+        attn_out.is_contiguous(),
+        "attn_out from at::_scaled_dot_product_attention_math is expected to "
+        "be contiguous.");
+    return std::make_tuple(attn_out, logsumexp, philox_seed, philox_offset);
   }

   NVF_ERROR(
@@ -442,7 +443,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
       last_dim_size);

   auto
-      [out,
+      [attn_out,
        logsumexp,
        cum_seq_q,
        cum_seq_k,
@@ -460,7 +461,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
           /*return_debug_mask=*/false,
           scale);

-    return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
+    return std::make_tuple(attn_out, logsumexp, philox_seed, philox_offset);
   }();

   if (batch_dims.size() > 1) {
@@ -468,6 +469,8 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
     logsumexp = unflattenBatchDim(logsumexp, batch_dims);
   }

+  output = relayoutByTensorView(output, attn_out());
+
   // We ignore cum_seq_q/k outputs since they are undefined tensors for
   // non-nested tensors. We do not store query/key_seq_len since they can be
   // computed in non-nested tensor directly. debug_attn_mask is ignored
diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp
index e7aea4e90fa..c730f70fea8 100644
--- a/csrc/preseg_passes/allocation_order_inference.cpp
+++ b/csrc/preseg_passes/allocation_order_inference.cpp
@@ -351,11 +351,12 @@ void inferAllocationOrder(
     }
   }

-// Propagate allocation orders from an SDPA's inputs to outputs. This is
-// necessary to make an SPDA's allocation domain consistent with the output
+// Propagates allocation orders from an SDPA's inputs to outputs. This is
+// necessary to make an SDPA's allocation domain consistent with the output
 // at::Tensor from expression evaluation. Currently, we call ATen to evaluate
 // SDPAs so matching their behavior, despite being fragile, is the best
-// solution.
+// solution I can think of. SdpaTest.FlashAttentionStrideOrder verifies that
+// the flash attention API indeed matches our expectations.
 class SdpaPropagator : public OptOutConstDispatch {
  public:
   void handle(const SdpaFwdOp* e) override {
@@ -364,6 +365,7 @@ class SdpaPropagator : public OptOutConstDispatch {
     // Don't propagate allocation to LSE because it's allocated as [B,H,S]:
     // https://github.com/pytorch/pytorch/blob/0db21a6b23fc6d7ccf6246dfd22f063694996144/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L454.
   }
+
   void handle(const SdpaBwdOp* e) override {
     // https://github.com/pytorch/pytorch/blob/7578a0b26836116fed4daecf2f08ff75a4b2dbea/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L904
     propagateAllocation(e->query(), e->grad_query());
diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp
index 3a28f02c1af..88661742243 100644
--- a/tests/cpp/test_sdpa_node.cpp
+++ b/tests/cpp/test_sdpa_node.cpp
@@ -5,16 +5,18 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
+#include
 #include

-#include "csrc/exceptions.h"
+#include
+#include
+
+#include "exceptions.h"
 #include "fusion.h"
 #include "multidevice/device_mesh.h"
 #include "ops/all_ops.h"
 #include "ops/utils.h"
 #include "optimization_pass.h"
-#include "preseg_passes/allocation_order_inference.h"
-#include "preseg_passes/move_split_cat.h"
 #include "preseg_passes/propagate_shardings.h"
 #include "tests/cpp/utils.h"
 #include "validator_utils.h"
@@ -227,6 +229,8 @@ void checkSdpaBwdMapping(Fusion* fusion, Expr* op) {
   }
 }

+using testing::ElementsAre;
+
 using SDPATest = NVFuserTest;

 TEST_F(SDPATest, NonCausalAttnConcrete) {
@@ -1177,4 +1181,58 @@ TEST_F(SDPATest, ComputeAt) {
   validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
 }

+// Verifies that the flash attention API matches what
+// https://github.com/NVIDIA/Fuser/blob/305907fed8ae18d1b7215dcba621b06f09d70e92/csrc/preseg_passes/allocation_order_inference.cpp#L358
+// expects.
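+// Concretely: for query/key/value whose memory layout is [batch, seq, heads,
+// hidden] (as produced by splitting a packed QKV tensor), the attention
+// output and the input gradients keep that layout even though their logical
+// shape is [batch, heads, seq, hidden].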
+TEST_F(SDPATest, FlashAttentionStrideOrder) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
+
+  at::Tensor qkv =
+      at::randn({n, s, h * e * 3}, at::dtype(at::kHalf).device(at::kCUDA));
+  std::vector<at::Tensor> splits =
+      at::chunk(qkv.view({n, s, h, e * 3}), /*chunks=*/3, /*dim=*/-1);
+  ASSERT_EQ(splits.size(), 3);
+  at::Tensor q = splits.at(0).permute({0, 2, 1, 3});
+  at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
+  at::Tensor v = splits.at(2).permute({0, 2, 1, 3});
+
+  auto outs = at::_scaled_dot_product_flash_attention(q, k, v);
+
+  at::Tensor attn_out = std::get<0>(outs);
+  at::Tensor logsumexp = std::get<1>(outs);
+  at::Tensor cum_seq_q = std::get<2>(outs);
+  at::Tensor cum_seq_k = std::get<3>(outs);
+  at::SymInt max_q = std::get<4>(outs);
+  at::SymInt max_k = std::get<5>(outs);
+  at::Tensor philox_seed = std::get<6>(outs);
+  at::Tensor philox_offset = std::get<7>(outs);
+
+  EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
+  EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();
+
+  auto [q_grad, k_grad, v_grad] =
+      at::_scaled_dot_product_flash_attention_backward_symint(
+          /*grad_output=*/attn_out, // This test merely verifies sizes and
+                                    // strides so it's fine to reuse `attn_out`
+                                    // as `grad_output`
+          q,
+          k,
+          v,
+          attn_out,
+          logsumexp,
+          cum_seq_q,
+          cum_seq_k,
+          max_q,
+          max_k,
+          /*dropout_p=*/0.0,
+          /*is_causal=*/false,
+          philox_seed,
+          philox_offset);
+
+  for (at::Tensor grad : {q_grad, k_grad, v_grad}) {
+    EXPECT_THAT(grad.sizes(), ElementsAre(n, h, s, e));
+    EXPECT_TRUE(grad.transpose(1, 2).is_contiguous()) << grad.strides();
+  }
+}
+
 } // namespace nvfuser
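Below is a minimal standalone sketch (not part of the patch) of the permute/contiguous/permute round trip that relayoutByTensorView performs, written against plain ATen. It assumes a program linked against libtorch; the tensor sizes and the permutation are illustrative only, chosen so the result has the same [b, s, h, e] memory layout the test above checks for.

// Illustrative sketch only -- not part of the patch above.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Logical shape [b, h, s, e], allocated contiguously in that order.
  at::Tensor t = at::randn({2, 3, 4, 5});
  // Desired allocation order [b, s, h, e]; this permutation is its own inverse.
  std::vector<int64_t> perm = {0, 2, 1, 3};

  at::Tensor relaid = t.permute(perm) // view the data in allocation order
                          .contiguous() // materialize memory in that order
                          .permute(perm); // restore the logical order

  // Logical shape is unchanged: prints [2, 3, 4, 5].
  std::cout << relaid.sizes() << std::endl;
  // Memory now follows [b, s, h, e]: prints 1.
  std::cout << relaid.transpose(1, 2).is_contiguous() << std::endl;
  return 0;
}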