152 changes: 63 additions & 89 deletions csrc/ir/composite_nodes.cpp
@@ -5,7 +5,6 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <algorithm>
#include <iterator>
#include <limits>
#include <sstream>
@@ -269,14 +268,15 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

namespace sdpa_meta {

namespace {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
const int batch_size = query.size(0);
const int num_heads = query.size(1);
const int seqlen_q = query.size(2);
const int out_head_dim = value.size(-1);
scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
NVF_ERROR_EQ(query.dim(), 4);
NVF_ERROR_EQ(value.dim(), 4);
const auto batch_size = query.size(0);
const auto num_heads = query.size(1);
const auto seqlen_q = query.size(2);
const auto out_head_dim = value.size(-1);

auto out = at::empty(
{batch_size, num_heads, seqlen_q, out_head_dim}, query.options());
@@ -294,10 +294,6 @@ _scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
return std::make_tuple(out, logsumexp, rng_state, rng_offset);
}

} // namespace sdpa_meta

namespace {

at::Tensor flattenBatchDims(at::Tensor t) {
at::DimVector new_shape({-1});
auto non_batch_dims = t.sizes().slice(t.dim() - 3);
@@ -311,6 +307,33 @@ at::Tensor unflattenBatchDim(at::Tensor t, at::IntArrayRef batch_dims) {
new_shape.append(non_batch_dims.begin(), non_batch_dims.end());
return t.view(new_shape);
}

// at::_scaled_dot_product_attention_math and
// scaled_dot_product_attention_meta produce a contiguous attention output,
// but SdpaFwdOp requires the attention output to be in the same layout as
// the query input:
// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
// Therefore, we relayout the attention output according to attn_out()'s
// allocation domain.
//
// at::_scaled_dot_product_flash_attention goes through this code as well
// but the `.contiguous()` will be a no-op.
at::Tensor relayoutByTensorView(at::Tensor t, TensorView* tv) {
const std::optional<Layout> layout = canonicalizeLayout(tv);
NVF_CHECK(layout.has_value(), "Failed to canonicalize output layout of ", tv);
const std::optional<std::vector<int64_t>> permutation =
ir_utils::computePermutation(
tv->getLogicalDomain(), layout->allocation_domain());
NVF_ERROR(
permutation.has_value(),
"The allocation domain of a canonicalized layout of ",
tv,
" is not a permutation of its logical domain.");
return t.permute(*permutation)
.contiguous()
.permute(ir_utils::inversePermutation(*permutation));
}

} // namespace

std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
@@ -366,10 +389,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
attn_bias = flattenBatchDims(attn_bias.contiguous());
}

// 4D SDPA
// 4D SDPA. `output`'s strides don't necessarily match `attn_out()`'s
// requirements, which will be fixed in the next section.
auto [output, log_sumexp, philox_seed, philox_offset] = [&]() {
if (query.is_meta()) {
return sdpa_meta::_scaled_dot_product_attention_meta(query, value);
return scaled_dot_product_attention_meta(query, value);
}

if (attn_bias.defined()) {
@@ -396,7 +420,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
"dropout_p is 0.0.");
auto philox_seed = at::empty({2}, query.options().dtype(at::kUInt64));
auto philox_offset = at::empty({}, query.options().dtype(at::kUInt64));
auto [out, log_sumexp] = at::_scaled_dot_product_attention_math(
auto [attn_out, log_sumexp] = at::_scaled_dot_product_attention_math(
query,
key,
value,
@@ -405,53 +429,21 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
is_causal,
/*dropout_mask=*/std::nullopt,
scale);

// at::_scaled_dot_product_attention_math produces a contiguous attention
// output, but SdpaFwdOp requires the attention output to be in the same
// layout as the query input:
// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
// Therefore, we relayout the attention output according to attn_out()'s
// allocation domain.
NVF_ERROR(out.is_contiguous());
const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
NVF_CHECK(
out_layout.has_value(),
"Failed to canonicalize output layout of ",
attn_out());
const std::optional<std::vector<int64_t>> permutation =
ir_utils::computePermutation(
attn_out()->getLogicalDomain(), out_layout->allocation_domain());
NVF_ERROR(
permutation.has_value(),
"The allocation domain of a canonicalized layout of ",
attn_out(),
" is not a permutation of its logical domain.");
out = unflattenBatchDim(out, batch_dims);
out = out.permute(*permutation)
.contiguous()
.permute(ir_utils::inversePermutation(*permutation));
out = flattenBatchDims(out);

return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
attn_out.is_contiguous(),
"attn_out from at::_scaled_dot_product_attention_math is expected to "
"be contiguous.");
return std::make_tuple(attn_out, log_sumexp, philox_seed, philox_offset);
}

// Flash attention require the last dimension to be padded to 8.
auto pad_last_dim = [last_dim_size](
at::Tensor inp, int alignment_size) -> at::Tensor {
if (last_dim_size % alignment_size == 0) {
return inp;
}
auto pad_count = alignment_size - (last_dim_size % alignment_size);
auto padded_inp = at::pad(inp, {0, pad_count});
return padded_inp;
};

query = pad_last_dim(query, 8);
key = pad_last_dim(key, 8);
value = pad_last_dim(value, 8);
NVF_ERROR(
last_dim_size % 8 == 0,
"Flash attention requires the last dimension to be a multiple of 8, "
"but got: ",
last_dim_size);

auto
[out,
[attn_out,
log_sumexp,
cum_seq_q,
cum_seq_k,
@@ -469,18 +461,16 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
/*return_debug_mask=*/false,
scale);

// If the inputs were padded, slice the output to restore the original size
if (out.size(-1) != last_dim_size) {
out = out.slice(-1, 0, last_dim_size);
}
return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
return std::make_tuple(attn_out, log_sumexp, philox_seed, philox_offset);
}();

if (batch_dims.size() > 1) {
output = unflattenBatchDim(output, batch_dims);
log_sumexp = unflattenBatchDim(log_sumexp, batch_dims);
}

output = relayoutByTensorView(output, attn_out());

// We ignore cum_seq_q/k outputs since they are undefined tensors for
// non-nested tensors. We do not store query/key_seq_len since they can be
// computed in non-nested tensor directly. debug_attn_mask is ignored
@@ -590,16 +580,11 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
// Flash attention requires the last dimension to be padded to 8.
// https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677
const auto last_dim_size = bwd_inputs[0].size(-1);
auto pad_last_dim = [last_dim_size](
at::Tensor inp, int alignment_size) -> at::Tensor {
if (last_dim_size % alignment_size == 0) {
return inp;
}
auto pad_count = alignment_size - (last_dim_size % alignment_size);
auto padded_inp = at::pad(inp, {0, pad_count});
return padded_inp;
};

NVF_ERROR(
last_dim_size % 8 == 0,
"Flash attention requires the last dimension to be a multiple of 8, but "
"got: ",
last_dim_size);
// Conmpute scale using original size of last dimension
Suggested change (remove this line):
// Conmpute scale using original size of last dimension

double scale = inputs.size() > 10 ? inputs.back().as<double>()
: 1.0 / std::sqrt(last_dim_size);
@@ -618,11 +603,11 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
const auto philox_offset = inputs.at(9).as<at::Tensor>();
std::tie(grad_query, grad_key, grad_value) =
at::_scaled_dot_product_flash_attention_backward(
/*grad_output=*/pad_last_dim(bwd_inputs[0], 8),
/*query=*/pad_last_dim(bwd_inputs[1], 8),
/*key=*/pad_last_dim(bwd_inputs[2], 8),
/*value=*/pad_last_dim(bwd_inputs[3], 8),
/*output=*/pad_last_dim(bwd_inputs[4], 8),
/*grad_output=*/bwd_inputs[0],
/*query=*/bwd_inputs[1],
/*key=*/bwd_inputs[2],
/*value=*/bwd_inputs[3],
/*output=*/bwd_inputs[4],
/*logsumexp=*/bwd_inputs[5],
/*cum_seq_q=*/at::Tensor(),
/*cum_seq_k=*/at::Tensor(),
@@ -636,25 +621,14 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
/*scale=*/scale);
}

// If the inputs were padded, slice the grads to restore the original size
auto slice_last_dim = [last_dim_size](at::Tensor output) -> at::Tensor {
if (output.size(-1) != last_dim_size) {
return output.slice(-1, 0, last_dim_size);
}
return output;
};

// Add device dimension back to outputs.
if (first_dim_is_did) {
grad_query = grad_query.unsqueeze(0);
grad_key = grad_key.unsqueeze(0);
grad_value = grad_value.unsqueeze(0);
}

return {
slice_last_dim(grad_query),
slice_last_dim(grad_key),
slice_last_dim(grad_value)};
return {grad_query, grad_key, grad_value};
}

EmbeddingFwdOp::EmbeddingFwdOp(
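The composite_nodes.cpp changes above drop the pad-and-slice handling around at::_scaled_dot_product_flash_attention and turn the alignment rule into a hard precondition on the head dimension. A minimal standalone sketch of that precondition, assuming only the C++ standard library and a hypothetical check_flash_attention_head_dim helper (the real code uses the NVF_ERROR macro):

#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical standalone version of the new check; NVF_ERROR is replaced by
// a plain exception so the sketch compiles outside the nvFuser tree.
void check_flash_attention_head_dim(
    int64_t last_dim_size,
    int64_t alignment = 8) {
  if (last_dim_size % alignment != 0) {
    throw std::runtime_error(
        "Flash attention requires the last dimension to be a multiple of " +
        std::to_string(alignment) + ", but got: " +
        std::to_string(last_dim_size));
  }
}

With padding gone, callers are expected to supply head dimensions that are already 8-aligned, which is why the Python test below switches e from 2 to 8.
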
8 changes: 5 additions & 3 deletions csrc/preseg_passes/allocation_order_inference.cpp
@@ -350,11 +350,12 @@ void inferAllocationOrder(
}
}

// Propagate allocation orders from an SDPA's inputs to outputs. This is
// necessary to make an SPDA's allocation domain consistent with the output
// Propagates allocation orders from an SDPA's inputs to outputs. This is
// necessary to make an SDPA's allocation domain consistent with the output
// at::Tensor from expression evaluation. Currently, we call ATen to evaluate
// SDPAs so matching their behavior, despite being fragile, is the best
// solution.
// solution I can think of. SdpaTest.FlashAttentionStrideOrder verifies the
// flash attention API indeed matches our expectations.
class SdpaPropagator : public OptOutConstDispatch {
public:
void handle(const SdpaFwdOp* e) override {
@@ -363,6 +364,7 @@ class SdpaPropagator : public OptOutConstDispatch {
// Don't propagate allocation to LSE because it's allocated as [B,H,S]:
// https://github.com/pytorch/pytorch/blob/0db21a6b23fc6d7ccf6246dfd22f063694996144/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L454.
}

void handle(const SdpaBwdOp* e) override {
// https://github.com/pytorch/pytorch/blob/7578a0b26836116fed4daecf2f08ff75a4b2dbea/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp#L904
propagateAllocation(e->query(), e->grad_query());
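SdpaPropagator copies the query's allocation order onto the attention output because flash attention returns a tensor whose logical shape is [N, H, S, E] but whose memory layout follows [N, S, H, E], matching a query sliced out of a packed QKV projection. A minimal libtorch-only sketch, assuming a hypothetical relayout_like_query helper and a hard-coded permutation {0, 2, 1, 3}, of the permute-contiguous-permute trick relayoutByTensorView uses to give the math/meta fallbacks that same layout:

#include <ATen/ATen.h>

#include <vector>

// Hypothetical standalone helper; the real relayoutByTensorView derives the
// permutation from the TensorView's canonicalized allocation domain rather
// than hard-coding it.
at::Tensor relayout_like_query(at::Tensor t) {
  // Logical order [N, H, S, E] -> allocation order [N, S, H, E]. Swapping two
  // dimensions is its own inverse, so the same permutation maps back.
  const std::vector<int64_t> perm = {0, 2, 1, 3};
  return t.permute(perm) // view the tensor in allocation order
      .contiguous() // materialize memory in that order
      .permute(perm); // restore the logical dimension order
}

// After relayout, t.sizes() is still [N, H, S, E] but t.transpose(1, 2) is
// contiguous, which is exactly what SdpaTest.FlashAttentionStrideOrder checks.
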
64 changes: 61 additions & 3 deletions tests/cpp/test_sdpa_node.cpp
@@ -5,16 +5,18 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "csrc/exceptions.h"
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>

#include "exceptions.h"
#include "fusion.h"
#include "multidevice/device_mesh.h"
#include "ops/all_ops.h"
#include "ops/utils.h"
#include "optimization_pass.h"
#include "preseg_passes/allocation_order_inference.h"
#include "preseg_passes/move_split_cat.h"
#include "preseg_passes/propagate_shardings.h"
#include "tests/cpp/utils.h"
#include "tests/cpp/validator.h"
@@ -227,6 +229,8 @@ void checkSdpaBwdMapping(Fusion* fusion, Expr* op) {
}
}

using testing::ElementsAre;

using SDPATest = NVFuserTest;

TEST_F(SDPATest, NonCausalAttnConcrete) {
@@ -1177,4 +1181,58 @@ TEST_F(SDPATest, ComputeAt) {
validateSdpaFwdOutputs(nvf_out, aten_out, aten_out_meta);
}

// Verifies the flash attention API matches what
// https://github.com/NVIDIA/Fuser/blob/305907fed8ae18d1b7215dcba621b06f09d70e92/csrc/preseg_passes/allocation_order_inference.cpp#L358
// expects.
TEST_F(SDPATest, FlashAttentionStrideOrder) {
NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

at::Tensor qkv =
at::randn({n, s, h * e * 3}, at::dtype(at::kHalf).device(at::kCUDA));
std::vector<at::Tensor> splits =
at::chunk(qkv.view({n, s, h, e * 3}), /*chunks=*/3, /*dim=*/-1);
ASSERT_EQ(splits.size(), 3);
at::Tensor q = splits.at(0).permute({0, 2, 1, 3});
at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
at::Tensor v = splits.at(2).permute({0, 2, 1, 3});

auto outs = at::_scaled_dot_product_flash_attention(q, k, v);

at::Tensor attn_out = std::get<0>(outs);
at::Tensor logsumexp = std::get<1>(outs);
at::Tensor cum_seq_q = std::get<2>(outs);
at::Tensor cum_seq_k = std::get<3>(outs);
at::SymInt max_q = std::get<4>(outs);
at::SymInt max_k = std::get<5>(outs);
at::Tensor philox_seed = std::get<6>(outs);
at::Tensor philox_offset = std::get<7>(outs);

EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();

auto [q_grad, k_grad, v_grad] =
at::_scaled_dot_product_flash_attention_backward_symint(
/*grad_output=*/attn_out, // This test merely verifies sizes and
// strides so it's fine to reuse `attn_out`
// as `grad_output`
q,
k,
v,
attn_out,
logsumexp,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
/*dropout_p=*/0.0,
/*is_causal=*/false,
philox_seed,
philox_offset);

for (at::Tensor grad : {q_grad, k_grad, v_grad}) {
EXPECT_THAT(grad.sizes(), ElementsAre(n, h, s, e));
EXPECT_TRUE(grad.transpose(1, 2).is_contiguous()) << grad.strides();
}
}

} // namespace nvfuser
2 changes: 1 addition & 1 deletion tests/python/direct/test_sdpa.py
@@ -47,7 +47,7 @@ def fusion_func(fd: FusionDefinition) -> None:
) = fd.ops.sdpfa_fwd(q, k, v, dropout_p=None, is_causal=None, scale=None)
fd.add_output(lse)

n, h, l, s, e = 1, 1, 4, 4, 2
n, h, l, s, e = 1, 1, 4, 4, 8
inputs = [
torch.ones((n, h, l, e), dtype=torch.bfloat16, device="cuda"),
torch.ones((n, h, s, e), dtype=torch.bfloat16, device="cuda"),