91 changes: 47 additions & 44 deletions csrc/ir/composite_nodes.cpp
@@ -5,7 +5,6 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <algorithm>
#include <iterator>
#include <limits>
#include <sstream>
@@ -269,14 +268,15 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

namespace sdpa_meta {

namespace {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
const int batch_size = query.size(0);
const int num_heads = query.size(1);
const int seqlen_q = query.size(2);
const int out_head_dim = value.size(-1);
scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
NVF_ERROR_EQ(query.dim(), 4);
NVF_ERROR_EQ(value.dim(), 4);
const auto batch_size = query.size(0);
const auto num_heads = query.size(1);
const auto seqlen_q = query.size(2);
const auto out_head_dim = value.size(-1);

auto out = at::empty(
{batch_size, num_heads, seqlen_q, out_head_dim}, query.options());
@@ -294,10 +294,6 @@ _scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
return std::make_tuple(out, logsumexp, rng_state, rng_offset);
}
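A minimal usage sketch of the shape-only helper above (the sizes are made up; only `scaled_dot_product_attention_meta` comes from this diff): building the inputs as meta tensors yields correctly shaped outputs without allocating device memory, which is what `SdpaFwdOp::evaluate` relies on when `query.is_meta()` is true.

// Hypothetical sketch: shape inference on meta tensors.
at::Tensor query = at::empty(
    {/*batch=*/2, /*heads=*/8, /*seq_q=*/128, /*head_dim=*/64},
    at::dtype(at::kHalf).device(at::kMeta));
at::Tensor value = at::empty(
    {/*batch=*/2, /*heads=*/8, /*seq_kv=*/128, /*head_dim=*/64},
    at::dtype(at::kHalf).device(at::kMeta));
auto [out, logsumexp, rng_state, rng_offset] =
    scaled_dot_product_attention_meta(query, value);
// out.sizes() == {2, 8, 128, 64}; no device memory is allocated.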

} // namespace sdpa_meta

namespace {

at::Tensor flattenBatchDims(at::Tensor t) {
at::DimVector new_shape({-1});
auto non_batch_dims = t.sizes().slice(t.dim() - 3);
@@ -311,6 +307,33 @@ at::Tensor unflattenBatchDim(at::Tensor t, at::IntArrayRef batch_dims) {
new_shape.append(non_batch_dims.begin(), non_batch_dims.end());
return t.view(new_shape);
}
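A small illustration of the two helpers above (shapes are made up, and the collapsed middle of `flattenBatchDims` is assumed to behave as its name suggests): the SDPA kernels take 4-D [B, H, S, E] inputs, so any extra leading batch dimensions are collapsed into one and restored afterwards.

// Hypothetical sketch of flattening and restoring leading batch dims.
at::Tensor t = at::randn({3, 5, 8, 128, 64});                  // [D0, D1, H, S, E]
at::IntArrayRef batch_dims = t.sizes().slice(0, t.dim() - 3);  // {3, 5}
at::Tensor flat = flattenBatchDims(t);                         // [15, 8, 128, 64]
at::Tensor restored = unflattenBatchDim(flat, batch_dims);     // [3, 5, 8, 128, 64]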

// at::_scaled_dot_product_attention_math and
// scaled_dot_product_attention_meta produce a contiguous attention output,
// but SdpaFwdOp requires the attention output to be in the same layout as
// the query input:
// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
// Therefore, we relayout the attention output according to attn_out()'s
// allocation domain.
//
// at::_scaled_dot_product_flash_attention goes through this code as well
// but the `.contiguous()` will be a no-op.
at::Tensor relayoutByTensorView(at::Tensor t, TensorView* tv) {
const std::optional<Layout> layout = canonicalizeLayout(tv);
NVF_CHECK(layout.has_value(), "Failed to canonicalize output layout of ", tv);
const std::optional<std::vector<int64_t>> permutation =
ir_utils::computePermutation(
tv->getLogicalDomain(), layout->allocation_domain());
NVF_ERROR(
permutation.has_value(),
"The allocation domain of a canonicalized layout of ",
tv,
" is not a permutation of its logical domain.");
return t.permute(*permutation)
.contiguous()
.permute(ir_utils::inversePermutation(*permutation));
}

} // namespace
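As a concrete, made-up illustration of the relayout trick in `relayoutByTensorView`: permuting to the allocation order, materializing contiguously, and permuting back preserves the logical shape while changing only how the data is laid out in memory. The shapes and permutation below are chosen for illustration; this particular permutation is its own inverse, which is why it appears twice.

// Hypothetical sketch: relayout a contiguous [B, H, S, E] tensor so memory is
// ordered as [B, S, H, E] while the logical order stays [B, H, S, E].
at::Tensor t = at::randn({2, 8, 128, 64});     // contiguous [B, H, S, E]
std::vector<int64_t> perm = {0, 2, 1, 3};      // desired allocation order
at::Tensor relaid = t.permute(perm)            // logical view [B, S, H, E]
                        .contiguous()          // materialize in that order
                        .permute(perm);        // {0, 2, 1, 3} is its own inverse
// relaid.sizes() == t.sizes(), and relaid.transpose(1, 2).is_contiguous().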

std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
@@ -366,10 +389,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
attn_bias = flattenBatchDims(attn_bias.contiguous());
}

// 4D SDPA
// 4D SDPA. `output`'s strides don't necessarily match `attn_out()`'s
// requirements, which will be fixed in the next section.
auto [output, logsumexp, philox_seed, philox_offset] = [&]() {
if (query.is_meta()) {
return sdpa_meta::_scaled_dot_product_attention_meta(query, value);
return scaled_dot_product_attention_meta(query, value);
}

if (attn_bias.defined()) {
@@ -396,7 +420,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
"dropout_p is 0.0.");
auto philox_seed = at::empty({2}, query.options().dtype(at::kUInt64));
auto philox_offset = at::empty({}, query.options().dtype(at::kUInt64));
auto [out, logsumexp] = at::_scaled_dot_product_attention_math(
auto [attn_out, logsumexp] = at::_scaled_dot_product_attention_math(
query,
key,
value,
@@ -405,34 +429,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
is_causal,
/*dropout_mask=*/std::nullopt,
scale);

// at::_scaled_dot_product_attention_math produces a contiguous attention
// output, but SdpaFwdOp requires the attention output to be in the same
// layout as the query input:
// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
// Therefore, we relayout the attention output according to attn_out()'s
// allocation domain.
NVF_ERROR(out.is_contiguous());
const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
NVF_CHECK(
out_layout.has_value(),
"Failed to canonicalize output layout of ",
attn_out());
const std::optional<std::vector<int64_t>> permutation =
ir_utils::computePermutation(
attn_out()->getLogicalDomain(), out_layout->allocation_domain());
NVF_ERROR(
permutation.has_value(),
"The allocation domain of a canonicalized layout of ",
attn_out(),
" is not a permutation of its logical domain.");
out = unflattenBatchDim(out, batch_dims);
out = out.permute(*permutation)
.contiguous()
.permute(ir_utils::inversePermutation(*permutation));
out = flattenBatchDims(out);

return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
NVF_ERROR(
attn_out.is_contiguous(),
"attn_out from at::_scaled_dot_product_attention_math is expected to "
"be contiguous.");
return std::make_tuple(attn_out, logsumexp, philox_seed, philox_offset);
}

NVF_ERROR(
@@ -442,7 +443,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
last_dim_size);

auto
[out,
[attn_out,
logsumexp,
cum_seq_q,
cum_seq_k,
@@ -460,14 +461,16 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
/*return_debug_mask=*/false,
scale);

return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
return std::make_tuple(attn_out, logsumexp, philox_seed, philox_offset);
}();

if (batch_dims.size() > 1) {
output = unflattenBatchDim(output, batch_dims);
logsumexp = unflattenBatchDim(logsumexp, batch_dims);
}

output = relayoutByTensorView(output, attn_out());

// We ignore cum_seq_q/k outputs since they are undefined tensors for
// non-nested tensors. We do not store query/key_seq_len since they can be
// computed in non-nested tensor directly. debug_attn_mask is ignored
8 changes: 5 additions & 3 deletions csrc/preseg_passes/allocation_order_inference.cpp
@@ -351,11 +351,12 @@ void inferAllocationOrder(
}
}

// Propagate allocation orders from an SDPA's inputs to outputs. This is
// necessary to make an SPDA's allocation domain consistent with the output
// Propagates allocation orders from an SDPA's inputs to outputs. This is
// necessary to make an SDPA's allocation domain consistent with the output
// at::Tensor from expression evaluation. Currently, we call ATen to evaluate
// SDPAs so matching their behavior, despite being fragile, is the best
// solution.
// solution I can think of. SdpaTest.FlashAttentionStrideOrder verifies the
// flash attention API indeed matches our expectations.
class SdpaPropagator : public OptOutConstDispatch {
public:
void handle(const SdpaFwdOp* e) override {
@@ -364,6 +365,7 @@ class SdpaPropagator : public OptOutConstDispatch {
// Don't propagate allocation to LSE because it's allocated as [B,H,S]:
// https://github.com/pytorch/pytorch/blob/0db21a6b23fc6d7ccf6246dfd22f063694996144/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L454.
}

void handle(const SdpaBwdOp* e) override {
// https://github.com/pytorch/pytorch/blob/7578a0b26836116fed4daecf2f08ff75a4b2dbea/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L904
propagateAllocation(e->query(), e->grad_query());
64 changes: 61 additions & 3 deletions tests/cpp/test_sdpa_node.cpp
@@ -5,16 +5,18 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "csrc/exceptions.h"
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>

#include "exceptions.h"
#include "fusion.h"
#include "multidevice/device_mesh.h"
#include "ops/all_ops.h"
#include "ops/utils.h"
#include "optimization_pass.h"
#include "preseg_passes/allocation_order_inference.h"
#include "preseg_passes/move_split_cat.h"
#include "preseg_passes/propagate_shardings.h"
#include "tests/cpp/utils.h"
#include "validator_utils.h"
@@ -227,6 +229,8 @@ void checkSdpaBwdMapping(Fusion* fusion, Expr* op) {
}
}

using testing::ElementsAre;

using SDPATest = NVFuserTest;

TEST_F(SDPATest, NonCausalAttnConcrete) {
@@ -1177,4 +1181,58 @@ TEST_F(SDPATest, ComputeAt) {
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

// Verifies the flash attention API matches what
// https://github.com/NVIDIA/Fuser/blob/305907fed8ae18d1b7215dcba621b06f09d70e92/csrc/preseg_passes/allocation_order_inference.cpp#L358
// expects.
TEST_F(SDPATest, FlashAttentionStrideOrder) {
NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

at::Tensor qkv =
at::randn({n, s, h * e * 3}, at::dtype(at::kHalf).device(at::kCUDA));
std::vector<at::Tensor> splits =
at::chunk(qkv.view({n, s, h, e * 3}), /*chunks=*/3, /*dim=*/-1);
ASSERT_EQ(splits.size(), 3);
at::Tensor q = splits.at(0).permute({0, 2, 1, 3});
at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
at::Tensor v = splits.at(2).permute({0, 2, 1, 3});

auto outs = at::_scaled_dot_product_flash_attention(q, k, v);

at::Tensor attn_out = std::get<0>(outs);
at::Tensor logsumexp = std::get<1>(outs);
at::Tensor cum_seq_q = std::get<2>(outs);
at::Tensor cum_seq_k = std::get<3>(outs);
at::SymInt max_q = std::get<4>(outs);
at::SymInt max_k = std::get<5>(outs);
at::Tensor philox_seed = std::get<6>(outs);
at::Tensor philox_offset = std::get<7>(outs);

EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();

auto [q_grad, k_grad, v_grad] =
at::_scaled_dot_product_flash_attention_backward_symint(
/*grad_output=*/attn_out, // This test merely verifies sizes and
// strides so it's fine to reuse `attn_out`
// as `grad_output`
q,
k,
v,
attn_out,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
/*dropout_p=*/0.0,
/*is_causal=*/false,
philox_seed,
philox_offset);

for (at::Tensor grad : {q_grad, k_grad, v_grad}) {
EXPECT_THAT(grad.sizes(), ElementsAre(n, h, s, e));
EXPECT_TRUE(grad.transpose(1, 2).is_contiguous()) << grad.strides();
}
}
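For reference, the `transpose(1, 2)` + `is_contiguous()` pattern used above is equivalent to checking explicit strides: a tensor with logical sizes (b, h, s, e) whose memory is ordered as [b, s, h, e] has strides (s * h * e, e, h * e, 1). A minimal, self-contained sketch of that equivalence (the test name and all sizes are made up, and it runs on CPU since only metadata is checked):

TEST_F(SDPATest, StrideOrderSketchExample) {
  constexpr int64_t b = 2, h = 4, s = 16, e = 8;
  // Contiguous in [b, s, h, e] order, then viewed logically as [b, h, s, e].
  at::Tensor t = at::randn({b, s, h, e}).permute({0, 2, 1, 3});
  EXPECT_THAT(t.sizes(), ElementsAre(b, h, s, e));
  EXPECT_TRUE(t.transpose(1, 2).is_contiguous());
  EXPECT_THAT(t.strides(), ElementsAre(s * h * e, e, h * e, 1));
}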

} // namespace nvfuser