From 25ad65fd3564102ab4b91c41810beaac20a75084 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Mon, 19 Jan 2026 20:25:06 -0800
Subject: [PATCH 1/8] More checks in _scaled_dot_product_attention_meta

---
 csrc/ir/composite_nodes.cpp             | 21 ++++++++-----------
 .../allocation_order_inference.cpp      |  1 +
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index c922397ff2b..a67260a63cb 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -269,14 +269,15 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const {
   NVF_CHECK(false, "Tensor op can not be printed inline");
 }
 
-namespace sdpa_meta {
-
+namespace {
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-_scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
-  const int batch_size = query.size(0);
-  const int num_heads = query.size(1);
-  const int seqlen_q = query.size(2);
-  const int out_head_dim = value.size(-1);
+scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
+  NVF_ERROR_EQ(query.dim(), 4);
+  NVF_ERROR_EQ(value.dim(), 4);
+  const auto batch_size = query.size(0);
+  const auto num_heads = query.size(1);
+  const auto seqlen_q = query.size(2);
+  const auto out_head_dim = value.size(-1);
 
   auto out = at::empty(
       {batch_size, num_heads, seqlen_q, out_head_dim}, query.options());
@@ -294,10 +295,6 @@ _scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
   return std::make_tuple(out, logsumexp, rng_state, rng_offset);
 }
 
-} // namespace sdpa_meta
-
-namespace {
-
 at::Tensor flattenBatchDims(at::Tensor t) {
   at::DimVector new_shape({-1});
   auto non_batch_dims = t.sizes().slice(t.dim() - 3);
@@ -369,7 +366,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
   // 4D SDPA
   auto [output, log_sumexp, philox_seed, philox_offset] = [&]() {
     if (query.is_meta()) {
-      return sdpa_meta::_scaled_dot_product_attention_meta(query, value);
+      return scaled_dot_product_attention_meta(query, value);
     }
 
     if (attn_bias.defined()) {
diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp
index af3aaf88596..23f33a4b1cb 100644
--- a/csrc/preseg_passes/allocation_order_inference.cpp
+++ b/csrc/preseg_passes/allocation_order_inference.cpp
@@ -363,6 +363,7 @@ class SdpaPropagator : public OptOutConstDispatch {
     // Don't propagate allocation to LSE because it's allocated as [B,H,S]:
     // https://github.com/pytorch/pytorch/blob/0db21a6b23fc6d7ccf6246dfd22f063694996144/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L454.
   }
+
   void handle(const SdpaBwdOp* e) override {
     // https://github.com/pytorch/pytorch/blob/7578a0b26836116fed4daecf2f08ff75a4b2dbea/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L904
     propagateAllocation(e->query(), e->grad_query());

From 29198e190e9a5262040056f7a5b9c01489273cb0 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Mon, 19 Jan 2026 20:31:08 -0800
Subject: [PATCH 2/8] Remove unnecessary includes

---
 tests/cpp/test_sdpa_node.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp
index 08b9baa250d..f54c93134f9 100644
--- a/tests/cpp/test_sdpa_node.cpp
+++ b/tests/cpp/test_sdpa_node.cpp
@@ -13,8 +13,6 @@
 #include "ops/all_ops.h"
 #include "ops/utils.h"
 #include "optimization_pass.h"
-#include "preseg_passes/allocation_order_inference.h"
-#include "preseg_passes/move_split_cat.h"
 #include "preseg_passes/propagate_shardings.h"
 #include "tests/cpp/utils.h"
 #include "tests/cpp/validator.h"

From d0840ef9f8847de8c9a469ae27bd476276f48c57 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Mon, 19 Jan 2026 22:41:40 -0800
Subject: [PATCH 3/8] Drop the support for head_dim that's not a multiple of 8

---
 csrc/ir/composite_nodes.cpp | 61 ++++++++++---------------------------
 1 file changed, 16 insertions(+), 45 deletions(-)

diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index a67260a63cb..e2102b8c0bd 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -432,20 +432,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
       return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
     }
 
-    // Flash attention require the last dimension to be padded to 8.
-    auto pad_last_dim = [last_dim_size](
-                            at::Tensor inp, int alignment_size) -> at::Tensor {
-      if (last_dim_size % alignment_size == 0) {
-        return inp;
-      }
-      auto pad_count = alignment_size - (last_dim_size % alignment_size);
-      auto padded_inp = at::pad(inp, {0, pad_count});
-      return padded_inp;
-    };
-
-    query = pad_last_dim(query, 8);
-    key = pad_last_dim(key, 8);
-    value = pad_last_dim(value, 8);
+    NVF_ERROR(
+        last_dim_size % 8 == 0,
+        "Flash attention requires the last dimension to be a multiple of 8, "
+        "but got: ",
+        last_dim_size);
 
     auto
         [out,
@@ -466,10 +457,6 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
         /*return_debug_mask=*/false,
         scale);
 
-    // If the inputs were padded, slice the output to restore the original size
-    if (out.size(-1) != last_dim_size) {
-      out = out.slice(-1, 0, last_dim_size);
-    }
     return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
   }();
 
@@ -587,16 +574,11 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
   // Flash attention requires the last dimension to be padded to 8.
   // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677
   const auto last_dim_size = bwd_inputs[0].size(-1);
-  auto pad_last_dim = [last_dim_size](
-                          at::Tensor inp, int alignment_size) -> at::Tensor {
-    if (last_dim_size % alignment_size == 0) {
-      return inp;
-    }
-    auto pad_count = alignment_size - (last_dim_size % alignment_size);
-    auto padded_inp = at::pad(inp, {0, pad_count});
-    return padded_inp;
-  };
-
+  NVF_ERROR(
+      last_dim_size % 8 == 0,
+      "Flash attention requires the last dimension to be a multiple of 8, but "
+      "got: ",
+      last_dim_size);
   // Conmpute scale using original size of last dimension
   double scale = inputs.size() > 10 ?
      inputs.back().as<double>() : 1.0 / std::sqrt(last_dim_size);
@@ -615,11 +597,11 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
     const auto philox_offset = inputs.at(9).as<at::Tensor>();
     std::tie(grad_query, grad_key, grad_value) =
         at::_scaled_dot_product_flash_attention_backward(
-            /*grad_output=*/pad_last_dim(bwd_inputs[0], 8),
-            /*query=*/pad_last_dim(bwd_inputs[1], 8),
-            /*key=*/pad_last_dim(bwd_inputs[2], 8),
-            /*value=*/pad_last_dim(bwd_inputs[3], 8),
-            /*output=*/pad_last_dim(bwd_inputs[4], 8),
+            /*grad_output=*/bwd_inputs[0],
+            /*query=*/bwd_inputs[1],
+            /*key=*/bwd_inputs[2],
+            /*value=*/bwd_inputs[3],
+            /*output=*/bwd_inputs[4],
             /*logsumexp=*/bwd_inputs[5],
             /*cum_seq_q=*/at::Tensor(),
             /*cum_seq_k=*/at::Tensor(),
@@ -633,14 +615,6 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
            /*scale=*/scale);
   }
 
-  // If the inputs were padded, slice the grads to restore the original size
-  auto slice_last_dim = [last_dim_size](at::Tensor output) -> at::Tensor {
-    if (output.size(-1) != last_dim_size) {
-      return output.slice(-1, 0, last_dim_size);
-    }
-    return output;
-  };
-
   // Add device dimension back to outputs.
   if (first_dim_is_did) {
     grad_query = grad_query.unsqueeze(0);
@@ -648,10 +622,7 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
     grad_value = grad_value.unsqueeze(0);
   }
 
-  return {
-      slice_last_dim(grad_query),
-      slice_last_dim(grad_key),
-      slice_last_dim(grad_value)};
+  return {grad_query, grad_key, grad_value};
 }
 
 EmbeddingFwdOp::EmbeddingFwdOp(

From be792ff6b79fa95a9ebacf092c87dd60b64f0fca Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Mon, 19 Jan 2026 22:50:29 -0800
Subject: [PATCH 4/8] Fix SdpaFwdOp's shape inference

---
 csrc/ir/composite_nodes.cpp  | 63 +++++++++++++++++++-----------------
 tests/cpp/test_sdpa_node.cpp | 22 +++++++++++++
 2 files changed, 55 insertions(+), 30 deletions(-)

diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index e2102b8c0bd..91ab3595f17 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -308,6 +308,7 @@ at::Tensor unflattenBatchDim(at::Tensor t, at::IntArrayRef batch_dims) {
   new_shape.append(non_batch_dims.begin(), non_batch_dims.end());
   return t.view(new_shape);
 }
+
 } // namespace
 
 std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
@@ -393,7 +394,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
           "dropout_p is 0.0.");
       auto philox_seed = at::empty({2}, query.options().dtype(at::kUInt64));
       auto philox_offset = at::empty({}, query.options().dtype(at::kUInt64));
-      auto [out, log_sumexp] = at::_scaled_dot_product_attention_math(
+      auto [attn_out, log_sumexp] = at::_scaled_dot_product_attention_math(
           query,
           key,
           value,
@@ -402,34 +403,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
           is_causal,
           /*dropout_mask=*/std::nullopt,
           scale);
-
-      // at::_scaled_dot_product_attention_math produces a contiguous attention
-      // output, but SdpaFwdOp requires the attention output to be in the same
-      // layout as the query input:
-      // https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
-      // Therefore, we relayout the attention output according to attn_out()'s
-      // allocation domain.
-      NVF_ERROR(out.is_contiguous());
-      const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
-      NVF_CHECK(
-          out_layout.has_value(),
-          "Failed to canonicalize output layout of ",
-          attn_out());
-      const std::optional<std::vector<int64_t>> permutation =
-          ir_utils::computePermutation(
-              attn_out()->getLogicalDomain(), out_layout->allocation_domain());
       NVF_ERROR(
-          permutation.has_value(),
-          "The allocation domain of a canonicalized layout of ",
-          attn_out(),
-          " is not a permutation of its logical domain.");
-      out = unflattenBatchDim(out, batch_dims);
-      out = out.permute(*permutation)
-                .contiguous()
-                .permute(ir_utils::inversePermutation(*permutation));
-      out = flattenBatchDims(out);
-
-      return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
+          attn_out.is_contiguous(),
+          "attn_out from at::_scaled_dot_product_attention_math is expected to "
+          "be contiguous.");
+      return std::make_tuple(attn_out, log_sumexp, philox_seed, philox_offset);
     }
 
     NVF_ERROR(
@@ -439,7 +417,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
         last_dim_size);
 
     auto
-        [out,
+        [attn_out,
         log_sumexp,
         cum_seq_q,
@@ -457,7 +435,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
         /*return_debug_mask=*/false,
         scale);
 
-    return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
+    return std::make_tuple(attn_out, log_sumexp, philox_seed, philox_offset);
   }();
 
   if (batch_dims.size() > 1) {
@@ -465,6 +443,31 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
     log_sumexp = unflattenBatchDim(log_sumexp, batch_dims);
   }
 
+  {
+    // at::_scaled_dot_product_attention_math produces a contiguous attention
+    // output, but SdpaFwdOp requires the attention output to be in the same
+    // layout as the query input:
+    // https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
+    // Therefore, we relayout the attention output according to attn_out()'s
+    // allocation domain.
+    const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
+    NVF_CHECK(
+        out_layout.has_value(),
+        "Failed to canonicalize output layout of ",
+        attn_out());
+    const std::optional<std::vector<int64_t>> permutation =
+        ir_utils::computePermutation(
+            attn_out()->getLogicalDomain(), out_layout->allocation_domain());
+    NVF_ERROR(
+        permutation.has_value(),
+        "The allocation domain of a canonicalized layout of ",
+        attn_out(),
+        " is not a permutation of its logical domain.");
+    output = output.permute(*permutation)
+                 .contiguous()
+                 .permute(ir_utils::inversePermutation(*permutation));
+  }
+
   // We ignore cum_seq_q/k outputs since they are undefined tensors for
   // non-nested tensors. We do not store query/key_seq_len since they can be
   // computed in non-nested tensor directly. debug_attn_mask is ignored
diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp
index f54c93134f9..86e2383fce1 100644
--- a/tests/cpp/test_sdpa_node.cpp
+++ b/tests/cpp/test_sdpa_node.cpp
@@ -5,6 +5,7 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
+#include 
 #include 
 
 #include "csrc/exceptions.h"
@@ -225,6 +226,8 @@ void checkSdpaBwdMapping(Fusion* fusion, Expr* op) {
   }
 }
 
+using testing::ElementsAre;
+
 using SDPATest = NVFuserTest;
 
 TEST_F(SDPATest, NonCausalAttnConcrete) {
@@ -1175,4 +1178,23 @@ TEST_F(SDPATest, ComputeAt) {
   validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
 }
 
+TEST_F(SDPATest, FlashAttentionStrideOrder) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
+
+  at::Tensor qkv =
+      at::randn({n, s, h * e * 3}, at::dtype(at::kHalf).device(at::kCUDA));
+  std::vector<at::Tensor> splits =
+      at::chunk(qkv.view({n, s, h, e * 3}), /*chunks=*/3, /*dim=*/-1);
+  ASSERT_EQ(splits.size(), 3);
+  at::Tensor q = splits.at(0).permute({0, 2, 1, 3});
+  at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
+  at::Tensor v = splits.at(2).permute({0, 2, 1, 3});
+
+  at::Tensor attn_out =
+      std::get<0>(at::_scaled_dot_product_flash_attention(q, k, v));
+
+  EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
+  EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();
+}
+
 } // namespace nvfuser

From b7698ac4c4196f24406d7b70faab33400ebdac94 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Mon, 19 Jan 2026 23:12:23 -0800
Subject: [PATCH 5/8] Minor

---
 csrc/ir/composite_nodes.cpp  | 1 -
 tests/cpp/test_sdpa_node.cpp | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index 91ab3595f17..b19caf2b49e 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -5,7 +5,6 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
-#include 
 #include 
 #include 
 #include 
diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp
index 86e2383fce1..a88c000b3f5 100644
--- a/tests/cpp/test_sdpa_node.cpp
+++ b/tests/cpp/test_sdpa_node.cpp
@@ -1179,6 +1179,7 @@ TEST_F(SDPATest, ComputeAt) {
 }
 
 TEST_F(SDPATest, FlashAttentionStrideOrder) {
+  // FIXME: test backprop too
   NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
 
   at::Tensor qkv =

From 78d40a40c76d838d1ff673fa81de312801a4fef8 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Tue, 20 Jan 2026 07:10:40 -0800
Subject: [PATCH 6/8] Fix test_sdpa.py

---
 tests/python/direct/test_sdpa.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/direct/test_sdpa.py b/tests/python/direct/test_sdpa.py
index f9dafdfb15a..c210f567453 100644
--- a/tests/python/direct/test_sdpa.py
+++ b/tests/python/direct/test_sdpa.py
@@ -47,7 +47,7 @@ def fusion_func(fd: FusionDefinition) -> None:
         ) = fd.ops.sdpfa_fwd(q, k, v, dropout_p=None, is_causal=None, scale=None)
         fd.add_output(lse)
 
-    n, h, l, s, e = 1, 1, 4, 4, 2
+    n, h, l, s, e = 1, 1, 4, 4, 8
     inputs = [
         torch.ones((n, h, l, e), dtype=torch.bfloat16, device="cuda"),
         torch.ones((n, h, s, e), dtype=torch.bfloat16, device="cuda"),

From 460a044dc66e63d3efc5dc09c0aa7c29f923c397 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Tue, 20 Jan 2026 07:45:11 -0800
Subject: [PATCH 7/8] Fix backprop and add comments

---
 .../allocation_order_inference.cpp |  7 +--
 tests/cpp/test_sdpa_node.cpp       | 45 +++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp
index 23f33a4b1cb..9a5e768eae4 100644
--- a/csrc/preseg_passes/allocation_order_inference.cpp
+++ b/csrc/preseg_passes/allocation_order_inference.cpp
@@ -350,11 +350,12 @@ void inferAllocationOrder(
   }
 }
 
-// Propagate allocation orders from an SDPA's inputs to outputs. This is
-// necessary to make an SPDA's allocation domain consistent with the output
+// Propagates allocation orders from an SDPA's inputs to outputs. This is
+// necessary to make an SDPA's allocation domain consistent with the output
 // at::Tensor from expression evaluation. Currently, we call ATen to evaluate
 // SDPAs so matching their behavior, despite being fragile, is the best
-// solution.
+// solution I can think of. SDPATest.FlashAttentionStrideOrder verifies the
+// flash attention API indeed matches our expectations.
 class SdpaPropagator : public OptOutConstDispatch {
  public:
  void handle(const SdpaFwdOp* e) override {
diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp
index a88c000b3f5..cbea5c3030c 100644
--- a/tests/cpp/test_sdpa_node.cpp
+++ b/tests/cpp/test_sdpa_node.cpp
@@ -8,7 +8,10 @@
 #include 
 #include 
 
-#include "csrc/exceptions.h"
+#include 
+#include 
+
+#include "exceptions.h"
 #include "fusion.h"
 #include "multidevice/device_mesh.h"
 #include "ops/all_ops.h"
@@ -1178,8 +1181,10 @@ TEST_F(SDPATest, ComputeAt) {
   validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
 }
 
+// Verifies the flash attention API matches what
+// https://github.com/NVIDIA/Fuser/blob/305907fed8ae18d1b7215dcba621b06f09d70e92/csrc/preseg_passes/allocation_order_inference.cpp#L358
+// expects.
 TEST_F(SDPATest, FlashAttentionStrideOrder) {
-  // FIXME: test backprop too
   NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
 
   at::Tensor qkv =
@@ -1191,11 +1196,43 @@ TEST_F(SDPATest, FlashAttentionStrideOrder) {
   at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
   at::Tensor v = splits.at(2).permute({0, 2, 1, 3});
 
-  at::Tensor attn_out =
-      std::get<0>(at::_scaled_dot_product_flash_attention(q, k, v));
+  auto outs = at::_scaled_dot_product_flash_attention(q, k, v);
+
+  at::Tensor attn_out = std::get<0>(outs);
+  at::Tensor logsumexp = std::get<1>(outs);
+  at::Tensor cum_seq_q = std::get<2>(outs);
+  at::Tensor cum_seq_k = std::get<3>(outs);
+  at::SymInt max_q = std::get<4>(outs);
+  at::SymInt max_k = std::get<5>(outs);
+  at::Tensor philox_seed = std::get<6>(outs);
+  at::Tensor philox_offset = std::get<7>(outs);
 
   EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
   EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();
+
+  auto [q_grad, k_grad, v_grad] =
+      at::_scaled_dot_product_flash_attention_backward_symint(
+          /*grad_output=*/attn_out, // This test merely verifies sizes and
+                                    // strides so it's fine to reuse `attn_out`
+                                    // as `grad_output`
+          q,
+          k,
+          v,
+          attn_out,
+          logsumexp,
+          cum_seq_q,
+          cum_seq_k,
+          max_q,
+          max_k,
+          /*dropout_p=*/0.0,
+          /*is_causal=*/false,
+          philox_seed,
+          philox_offset);
+
+  for (at::Tensor grad : {q_grad, k_grad, v_grad}) {
+    EXPECT_THAT(grad.sizes(), ElementsAre(n, h, s, e));
+    EXPECT_TRUE(grad.transpose(1, 2).is_contiguous()) << grad.strides();
+  }
 }
 
 } // namespace nvfuser

From d2480844c41cdf795e27dc38381cee66ee19a6b7 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Tue, 20 Jan 2026 08:38:37 -0800
Subject: [PATCH 8/8] relayoutByTensorView as a function

---
 csrc/ir/composite_nodes.cpp | 54 ++++++++++++++++++++-----------------
 1 file changed, 29 insertions(+), 25 deletions(-)

diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index b19caf2b49e..e3b87b48b83 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -308,6 +308,32 @@ at::Tensor unflattenBatchDim(at::Tensor t, at::IntArrayRef batch_dims) {
   return t.view(new_shape);
 }
 
+// at::_scaled_dot_product_attention_math and
+// scaled_dot_product_attention_meta produce a contiguous attention output,
+// but SdpaFwdOp requires the attention output to be in the same layout as
+// the query input:
+// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
+// Therefore, we relayout the attention output according to attn_out()'s
+// allocation domain.
+//
+// at::_scaled_dot_product_flash_attention goes through this code as well
+// but the `.contiguous()` will be a no-op.
+at::Tensor relayoutByTensorView(at::Tensor t, TensorView* tv) {
+  const std::optional<Layout> layout = canonicalizeLayout(tv);
+  NVF_CHECK(layout.has_value(), "Failed to canonicalize output layout of ", tv);
+  const std::optional<std::vector<int64_t>> permutation =
+      ir_utils::computePermutation(
+          tv->getLogicalDomain(), layout->allocation_domain());
+  NVF_ERROR(
+      permutation.has_value(),
+      "The allocation domain of a canonicalized layout of ",
+      tv,
+      " is not a permutation of its logical domain.");
+  return t.permute(*permutation)
+      .contiguous()
+      .permute(ir_utils::inversePermutation(*permutation));
+}
+
 } // namespace
 
 std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
@@ -363,7 +389,8 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
     attn_bias = flattenBatchDims(attn_bias.contiguous());
   }
 
-  // 4D SDPA
+  // 4D SDPA. `output`'s strides don't necessarily match `attn_out()`'s
+  // requirements, which will be fixed in the next section.
   auto [output, log_sumexp, philox_seed, philox_offset] = [&]() {
     if (query.is_meta()) {
       return scaled_dot_product_attention_meta(query, value);
@@ -442,30 +469,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
     log_sumexp = unflattenBatchDim(log_sumexp, batch_dims);
   }
 
-  {
-    // at::_scaled_dot_product_attention_math produces a contiguous attention
-    // output, but SdpaFwdOp requires the attention output to be in the same
-    // layout as the query input:
-    // https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
-    // Therefore, we relayout the attention output according to attn_out()'s
-    // allocation domain.
-    const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
-    NVF_CHECK(
-        out_layout.has_value(),
-        "Failed to canonicalize output layout of ",
-        attn_out());
-    const std::optional<std::vector<int64_t>> permutation =
-        ir_utils::computePermutation(
-            attn_out()->getLogicalDomain(), out_layout->allocation_domain());
-    NVF_ERROR(
-        permutation.has_value(),
-        "The allocation domain of a canonicalized layout of ",
-        attn_out(),
-        " is not a permutation of its logical domain.");
-    output = output.permute(*permutation)
-                 .contiguous()
-                 .permute(ir_utils::inversePermutation(*permutation));
-  }
+  output = relayoutByTensorView(output, attn_out());
 
   // We ignore cum_seq_q/k outputs since they are undefined tensors for
   // non-nested tensors. We do not store query/key_seq_len since they can be