SdpaFwdOp::evaluate computes meta tensors respecting allocation domains. #5848
Merged

Commits (9)
All commits by wujingyue:

- 25ad65f: More checks in _scaled_dot_product_attention_meta
- 29198e1: Remove unnecessary includes
- d0840ef: Drop the support for head_dim that's not a multiple of 8
- be792ff: Fix SdpaFwdOp's shape inference
- b7698ac: Minor
- 78d40a4: Fix test_sdpa.py
- 460a044: Fix backprop and add comments
- d248084: relayoutByTensorView as a function
- 5e114fd: Merge branch 'main' into wjy/sdpa-meta
Diff
```diff
@@ -5,7 +5,6 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
-#include <algorithm>
 #include <iterator>
 #include <limits>
 #include <sstream>
```
```diff
@@ -269,14 +268,15 @@ std::string SdpaFwdOp::toInlineString(int indent_size) const {
   NVF_CHECK(false, "Tensor op can not be printed inline");
 }
 
-namespace sdpa_meta {
-
+namespace {
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-_scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
-  const int batch_size = query.size(0);
-  const int num_heads = query.size(1);
-  const int seqlen_q = query.size(2);
-  const int out_head_dim = value.size(-1);
+scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
+  NVF_ERROR_EQ(query.dim(), 4);
+  NVF_ERROR_EQ(value.dim(), 4);
+  const auto batch_size = query.size(0);
+  const auto num_heads = query.size(1);
+  const auto seqlen_q = query.size(2);
+  const auto out_head_dim = value.size(-1);
 
   auto out = at::empty(
       {batch_size, num_heads, seqlen_q, out_head_dim}, query.options());
```
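The meta helper only has to get output shapes and options right, because meta tensors carry no data. Below is a standalone sketch of the same idea in plain ATen; the helper name `sdpa_meta_shapes` and the `[B, H, S_q]`, fp32 logsumexp layout are illustrative assumptions, not taken from the PR:

```cpp
#include <ATen/ATen.h>

#include <iostream>
#include <tuple>

// Given 4-D query [B, H, S_q, E] and value [B, H, S_kv, E_v] meta tensors,
// allocate an attention output [B, H, S_q, E_v] and a logsumexp [B, H, S_q]
// without touching real device memory.
std::tuple<at::Tensor, at::Tensor> sdpa_meta_shapes(
    const at::Tensor& query,
    const at::Tensor& value) {
  const auto batch_size = query.size(0);
  const auto num_heads = query.size(1);
  const auto seqlen_q = query.size(2);
  const auto out_head_dim = value.size(-1);
  auto out = at::empty(
      {batch_size, num_heads, seqlen_q, out_head_dim}, query.options());
  auto logsumexp = at::empty(
      {batch_size, num_heads, seqlen_q}, query.options().dtype(at::kFloat));
  return {out, logsumexp};
}

int main() {
  auto opts = at::TensorOptions().device(at::kMeta).dtype(at::kHalf);
  auto q = at::empty({2, 8, 128, 64}, opts);
  auto v = at::empty({2, 8, 128, 64}, opts);
  auto [out, lse] = sdpa_meta_shapes(q, v);
  std::cout << out.sizes() << " " << lse.sizes() << "\n";  // [2, 8, 128, 64] [2, 8, 128]
  return 0;
}
```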
```diff
@@ -294,10 +294,6 @@ _scaled_dot_product_attention_meta(at::Tensor query, at::Tensor value) {
   return std::make_tuple(out, logsumexp, rng_state, rng_offset);
 }
 
-} // namespace sdpa_meta
-
-namespace {
-
 at::Tensor flattenBatchDims(at::Tensor t) {
   at::DimVector new_shape({-1});
   auto non_batch_dims = t.sizes().slice(t.dim() - 3);
```
```diff
@@ -311,6 +307,33 @@ at::Tensor unflattenBatchDim(at::Tensor t, at::IntArrayRef batch_dims) {
   new_shape.append(non_batch_dims.begin(), non_batch_dims.end());
   return t.view(new_shape);
 }
+
+// at::_scaled_dot_product_attention_math and
+// scaled_dot_product_attention_meta produce a contiguous attention output,
+// but SdpaFwdOp requires the attention output to be in the same layout as
+// the query input:
+// https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
+// Therefore, we relayout the attention output according to attn_out()'s
+// allocation domain.
+//
+// at::_scaled_dot_product_flash_attention goes through this code as well
+// but the `.contiguous()` will be a no-op.
+at::Tensor relayoutByTensorView(at::Tensor t, TensorView* tv) {
+  const std::optional<Layout> layout = canonicalizeLayout(tv);
+  NVF_CHECK(layout.has_value(), "Failed to canonicalize output layout of ", tv);
+  const std::optional<std::vector<int64_t>> permutation =
+      ir_utils::computePermutation(
+          tv->getLogicalDomain(), layout->allocation_domain());
+  NVF_ERROR(
+      permutation.has_value(),
+      "The allocation domain of a canonicalized layout of ",
+      tv,
+      " is not a permutation of its logical domain.");
+  return t.permute(*permutation)
+      .contiguous()
+      .permute(ir_utils::inversePermutation(*permutation));
+}
+
 } // namespace
 
 std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
```
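The relayout trick itself is independent of nvFuser's IR: permute into the desired allocation order, force contiguity, then permute back, so the logical dimension order is unchanged while the strides now follow the target order. A minimal sketch with plain ATen, using a hypothetical NHWC-like target order (the `relayout` helper and the `{0, 2, 3, 1}` order are assumptions for illustration, not nvFuser code):

```cpp
#include <ATen/ATen.h>

#include <iostream>
#include <vector>

// Rearrange memory so that `t`'s strides follow `alloc_order` (logical dims
// listed from outermost to innermost in memory) while its logical dimension
// order stays the same.
at::Tensor relayout(const at::Tensor& t, const std::vector<int64_t>& alloc_order) {
  // Invert the permutation so the logical order can be restored afterwards.
  std::vector<int64_t> inverse(alloc_order.size());
  for (size_t i = 0; i < alloc_order.size(); ++i) {
    inverse[alloc_order[i]] = static_cast<int64_t>(i);
  }
  return t.permute(alloc_order).contiguous().permute(inverse);
}

int main() {
  auto t = at::arange(2 * 3 * 4 * 5).reshape({2, 3, 4, 5});
  auto r = relayout(t, {0, 2, 3, 1});
  std::cout << r.sizes() << "\n";    // [2, 3, 4, 5]: logical shape unchanged
  std::cout << r.strides() << "\n";  // [60, 1, 15, 3]: dim 1 is now innermost
  std::cout << r.equal(t) << "\n";   // 1: values are identical
  return 0;
}
```

When the input's memory layout already matches the target order, the inner `.contiguous()` returns the tensor unchanged, which is why routing `at::_scaled_dot_product_flash_attention` results through the same helper costs nothing.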
```diff
@@ -366,10 +389,11 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
     attn_bias = flattenBatchDims(attn_bias.contiguous());
   }
 
-  // 4D SDPA
+  // 4D SDPA. `output`'s strides don't necessarily match `attn_out()`'s
+  // requirements, which will be fixed in the next section.
   auto [output, log_sumexp, philox_seed, philox_offset] = [&]() {
     if (query.is_meta()) {
-      return sdpa_meta::_scaled_dot_product_attention_meta(query, value);
+      return scaled_dot_product_attention_meta(query, value);
     }
 
     if (attn_bias.defined()) {
```
```diff
@@ -396,7 +420,7 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
           "dropout_p is 0.0.");
       auto philox_seed = at::empty({2}, query.options().dtype(at::kUInt64));
       auto philox_offset = at::empty({}, query.options().dtype(at::kUInt64));
-      auto [out, log_sumexp] = at::_scaled_dot_product_attention_math(
+      auto [attn_out, log_sumexp] = at::_scaled_dot_product_attention_math(
           query,
           key,
           value,
```
```diff
@@ -405,53 +429,21 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
           is_causal,
           /*dropout_mask=*/std::nullopt,
           scale);
 
-      // at::_scaled_dot_product_attention_math produces a contiguous attention
-      // output, but SdpaFwdOp requires the attention output to be in the same
-      // layout as the query input:
-      // https://github.com/NVIDIA/Fuser/blob/fe23484180f47f8ac27a3527fdbcef2ff1be2a66/csrc/preseg_passes/allocation_order_inference.cpp#L361-L362.
-      // Therefore, we relayout the attention output according to attn_out()'s
-      // allocation domain.
-      NVF_ERROR(out.is_contiguous());
-      const std::optional<Layout> out_layout = canonicalizeLayout(attn_out());
-      NVF_CHECK(
-          out_layout.has_value(),
-          "Failed to canonicalize output layout of ",
-          attn_out());
-      const std::optional<std::vector<int64_t>> permutation =
-          ir_utils::computePermutation(
-              attn_out()->getLogicalDomain(), out_layout->allocation_domain());
-      NVF_ERROR(
-          permutation.has_value(),
-          "The allocation domain of a canonicalized layout of ",
-          attn_out(),
-          " is not a permutation of its logical domain.");
-      out = unflattenBatchDim(out, batch_dims);
-      out = out.permute(*permutation)
-                .contiguous()
-                .permute(ir_utils::inversePermutation(*permutation));
-      out = flattenBatchDims(out);
-
-      return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
+      NVF_ERROR(
+          attn_out.is_contiguous(),
+          "attn_out from at::_scaled_dot_product_attention_math is expected to "
+          "be contiguous.");
+      return std::make_tuple(attn_out, log_sumexp, philox_seed, philox_offset);
     }
 
-    // Flash attention require the last dimension to be padded to 8.
-    auto pad_last_dim = [last_dim_size](
-                            at::Tensor inp, int alignment_size) -> at::Tensor {
-      if (last_dim_size % alignment_size == 0) {
-        return inp;
-      }
-      auto pad_count = alignment_size - (last_dim_size % alignment_size);
-      auto padded_inp = at::pad(inp, {0, pad_count});
-      return padded_inp;
-    };
-
-    query = pad_last_dim(query, 8);
-    key = pad_last_dim(key, 8);
-    value = pad_last_dim(value, 8);
+    NVF_ERROR(
+        last_dim_size % 8 == 0,
+        "Flash attention requires the last dimension to be a multiple of 8, "
+        "but got: ",
+        last_dim_size);
 
     auto
-        [out,
+        [attn_out,
         log_sumexp,
         cum_seq_q,
         cum_seq_k,
```
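SdpaFwdOp::evaluate now rejects unaligned head dims instead of padding them (SdpaBwdOp::evaluate gets the same treatment below). For reference, the dropped work-around is easy to reproduce outside nvFuser: pad the last dimension up to the next multiple of 8 before the flash kernel and slice the result back afterwards. A standalone sketch of that behavior; `pad_last_dim` mirrors the removed lambda, and the shapes are illustrative:

```cpp
#include <ATen/ATen.h>

#include <iostream>

// Pad the last dimension of `t` up to the next multiple of `alignment`
// (flash attention kernels require head_dim % 8 == 0), returning `t`
// unchanged when it is already aligned.
at::Tensor pad_last_dim(const at::Tensor& t, int64_t alignment) {
  const int64_t last = t.size(-1);
  if (last % alignment == 0) {
    return t;
  }
  const int64_t pad_count = alignment - last % alignment;
  return at::pad(t, {0, pad_count});
}

int main() {
  auto q = at::randn({2, 8, 128, 61});
  auto padded = pad_last_dim(q, 8);
  std::cout << padded.size(-1) << "\n";  // 64
  // After running the aligned kernel, the output would be sliced back:
  auto restored = padded.slice(/*dim=*/-1, /*start=*/0, /*end=*/61);
  std::cout << restored.size(-1) << "\n";  // 61
  return 0;
}
```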
```diff
@@ -469,18 +461,16 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
             /*return_debug_mask=*/false,
             scale);
 
-    // If the inputs were padded, slice the output to restore the original size
-    if (out.size(-1) != last_dim_size) {
-      out = out.slice(-1, 0, last_dim_size);
-    }
-    return std::make_tuple(out, log_sumexp, philox_seed, philox_offset);
+    return std::make_tuple(attn_out, log_sumexp, philox_seed, philox_offset);
   }();
 
   if (batch_dims.size() > 1) {
     output = unflattenBatchDim(output, batch_dims);
     log_sumexp = unflattenBatchDim(log_sumexp, batch_dims);
   }
 
+  output = relayoutByTensorView(output, attn_out());
+
   // We ignore cum_seq_q/k outputs since they are undefined tensors for
   // non-nested tensors. We do not store query/key_seq_len since they can be
   // computed in non-nested tensor directly. debug_attn_mask is ignored
```
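The surrounding flatten/unflatten pair is what lets evaluate accept more than four dimensions: every leading batch dim is merged into one before the strictly 4-D SDPA call, restored afterwards, and only then is the result relaid out for `attn_out()`. A standalone sketch of that round trip; the helper names mirror the diff, but the code and the 5-D shape are illustrative assumptions, not the PR's exact implementation:

```cpp
#include <ATen/ATen.h>

#include <iostream>

// Collapse every leading batch dim into one (keeping the trailing
// [heads, seq, head_dim] triple) so a strictly 4-D SDPA kernel can be
// called, then restore the original batch dims afterwards.
at::Tensor flatten_batch(const at::Tensor& t) {
  at::DimVector shape({-1});
  auto tail = t.sizes().slice(t.dim() - 3);
  shape.append(tail.begin(), tail.end());
  return t.view(shape);
}

at::Tensor unflatten_batch(const at::Tensor& t, at::IntArrayRef batch_dims) {
  at::DimVector shape(batch_dims.begin(), batch_dims.end());
  auto tail = t.sizes().slice(t.dim() - 3);
  shape.append(tail.begin(), tail.end());
  return t.view(shape);
}

int main() {
  // e.g. [device_mesh=2, batch=3, heads=4, seq=8, head_dim=16]
  auto q = at::randn({2, 3, 4, 8, 16});
  auto batch_dims = q.sizes().slice(0, q.dim() - 3);  // {2, 3}
  auto q4 = flatten_batch(q);                         // [6, 4, 8, 16]
  auto restored = unflatten_batch(q4, batch_dims);    // [2, 3, 4, 8, 16]
  std::cout << q4.sizes() << " -> " << restored.sizes() << "\n";
  return 0;
}
```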
```diff
@@ -590,16 +580,11 @@ std::vector<PolymorphicValue> SdpaBwdOp::evaluate(
   // Flash attention requires the last dimension to be padded to 8.
   // https://github.com/pytorch/pytorch/blob/c27882ffa8c1c7e4cf8ebc6c2f879e5b6c8814ad/aten/src/ATen/native/transformers/attention.cpp#L675-L677
   const auto last_dim_size = bwd_inputs[0].size(-1);
-  auto pad_last_dim = [last_dim_size](
-                          at::Tensor inp, int alignment_size) -> at::Tensor {
-    if (last_dim_size % alignment_size == 0) {
-      return inp;
-    }
-    auto pad_count = alignment_size - (last_dim_size % alignment_size);
-    auto padded_inp = at::pad(inp, {0, pad_count});
-    return padded_inp;
-  };
-
+  NVF_ERROR(
+      last_dim_size % 8 == 0,
+      "Flash attention requires the last dimension to be a multiple of 8, but "
+      "got: ",
+      last_dim_size);
   // Conmpute scale using original size of last dimension
```