
Commit f56cb41

drisspg authored and pytorchmergebot committed
Fix calls to sizes to enable dynamic shapes with sdpa (pytorch#96674)
Fixes part of pytorch#96414. Replaces any calls to sizes with sym_sizes.

Still seeing an error with the repro script:

```
Exception raised from sizes_default at /scratch/drisspg/work/pytorch/c10/core/TensorImpl.h:635 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x7d (0x7f697f4a141d in /scratch/drisspg/work/pytorch/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0xdd (0x7f697f49fbcd in /scratch/drisspg/work/pytorch/torch/lib/libc10.so)
frame #2: c10::TensorImpl::sizes_custom() const + 0x95 (0x7f697f4824c5 in /scratch/drisspg/work/pytorch/torch/lib/libc10.so)
frame #3: at::native::empty_like(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, c10::optional<c10::MemoryFormat>) + 0x92c (0x7f69809d18ac in /scratch/drisspg/work/pytorch/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x23f5ce7 (0x7f698193bce7 in /scratch/drisspg/work/pytorch/torch/lib/libtorch_cpu.so)
```

I am still trying to track down this empty call; from the looks of it, it might be coming from at::layer_norm. The backtrace from lldb is 221 frames, however, so there is a lot of noise.

Pull Request resolved: pytorch#96674
Approved by: https://github.com/ezyang
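The change itself is mechanical: each concrete-size accessor used by the SDPA dispatch checks (`Tensor::size(dim)` / `sizes()`) is replaced with its symbolic counterpart (`sym_size(dim)` / `sym_sizes()`), which returns a `c10::SymInt` and keeps working when a tensor's sizes are symbolic under dynamic shapes; the `sizes_default` error quoted above is what the concrete accessors raise in that situation. A minimal sketch of the pattern, using a hypothetical helper that is not part of this patch:

```cpp
#include <ATen/ATen.h>
#include <c10/core/SymInt.h>

// Hypothetical helper (not in the patch) showing the before/after pattern:
// query.size(-1) returns a concrete int64_t and throws for tensors with
// symbolic sizes, while query.sym_size(-1) returns a c10::SymInt that can
// stay symbolic.
bool head_dim_ok(const at::Tensor& query) {
  const c10::SymInt head_dim = query.sym_size(-1);
  // Arithmetic and comparisons are defined directly on c10::SymInt,
  // mirroring the rewritten checks in check_head_dim_size below.
  return head_dim % 8 == 0 && head_dim <= 128;
}
```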
1 parent 218eeac · commit f56cb41

2 files changed: 33 additions, 33 deletions


aten/src/ATen/native/transformers/cuda/sdp_utils.h

Lines changed: 32 additions & 30 deletions

```diff
@@ -16,6 +16,7 @@
 
 #include <functional>
 #include <cmath>
+#include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
 
 namespace sdp {
@@ -55,27 +56,28 @@ inline std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
   // FlashAttention parallelizes across "batch_size * num_heads"
   // MemEff parallelizes across "batch_size * num_heads * num_queries" and can
   // be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
-  if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) {
+  if (params.query.is_nested() || params.key.is_nested() ||
+      params.value.is_nested()) {
     // See check_for_nested_inputs for details
     return {
         SDPBackend::efficient_attention,
         SDPBackend::flash_attention,
         SDPBackend::math};
   }
-  const auto sizes = params.query.sizes();
   if (params.query.dim() != 4) {
     return default_order;
   }
-  const auto batch_size{sizes[0]}, num_heads{sizes[1]}, query_lengths{sizes[2]},
-      head_dim{sizes[3]};
+  const auto batch_size{params.query.sym_size(0)},
+      num_heads{params.query.sym_size(1)},
+      query_lengths{params.query.sym_size(2)},
+      head_dim{params.query.sym_size(3)};
   if (batch_size > 0) {
-    const int64_t threads_flash = batch_size * num_heads;
-    const int64_t threads_cutlass =
-        threads_flash * (int64_t)std::floor(query_lengths / 64);
-    bool more_threads_cutlass =
-        (int64_t)std::floor(threads_cutlass / 2) >= threads_flash;
+    const auto threads_flash = batch_size * num_heads;
+    const auto threads_cutlass =
+        threads_flash * (query_lengths / c10::SymInt(64));
+    bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
     bool small_threads_flash = threads_flash < 60;
-    bool large_head_dim = std::max(head_dim, params.key.sizes()[3]) == 128;
+    bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
     if ((small_threads_flash && more_threads_cutlass) || large_head_dim) {
       return {
           SDPBackend::efficient_attention,
@@ -130,9 +132,9 @@ inline bool check_for_nested_inputs(sdp_params params){
   return false;
 }
 
-inline bool try_broadcast_param_size(int64_t q_size,
-    int64_t k_size,
-    int64_t v_size,
+inline bool try_broadcast_param_size(const c10::SymInt q_size,
+    const c10::SymInt k_size,
+    const c10::SymInt v_size,
     c10::string_view param_name,
     bool debug) {
   auto max_size = std::max({q_size, k_size, v_size});
@@ -329,9 +331,9 @@ inline bool check_safe_kv_broadcast(at::Tensor param, bool debug){
 inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
   // This is expected to be called after check_tensor_shapes ensuring that the size()
   // calls won't error since the inputs are all 4 dimensional
-  auto q_batch_size = params.query.size(0);
-  auto k_batch_size = params.key.size(0);
-  auto v_batch_size = params.value.size(0);
+  auto q_batch_size = params.query.sym_size(0);
+  auto k_batch_size = params.key.sym_size(0);
+  auto v_batch_size = params.value.sym_size(0);
 
   bool has_nested_input = check_for_nested_inputs(params);
   bool same_batch_size = q_batch_size == k_batch_size && q_batch_size == v_batch_size;
@@ -362,9 +364,9 @@ inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
     return broadcastable_batch_size;
   }
 
-  auto q_num_heads = params.query.size(1);
-  auto k_num_heads = params.key.size(1);
-  auto v_num_heads = params.value.size(1);
+  auto q_num_heads = params.query.sym_size(1);
+  auto k_num_heads = params.key.sym_size(1);
+  auto v_num_heads = params.value.sym_size(1);
   bool same_num_heads = q_num_heads == k_num_heads && q_num_heads == v_num_heads;
 
   if (!(same_batch_size && same_num_heads)) {
@@ -385,9 +387,9 @@ inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
 }
 
 inline bool check_head_dim_size(sdp_params params, bool debug) {
-  const int64_t query_size_last = params.query.size(-1);
-  const int64_t key_size_last = params.key.size(-1);
-  const int64_t value_size_last = params.value.size(-1);
+  const auto query_size_last = params.query.sym_size(-1);
+  const auto key_size_last = params.key.sym_size(-1);
+  const auto value_size_last = params.value.sym_size(-1);
   if (!(query_size_last == key_size_last &&
         query_size_last == value_size_last && query_size_last % 8 == 0 &&
         query_size_last <= 128 && value_size_last % 8 == 0 &&
@@ -398,9 +400,9 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
           " Got Query.size(-1): ",
           query_size_last,
           ", Key.size(-1): ",
-          params.key.size(-1),
+          params.key.sym_size(-1),
           ", Value.size(-1): ",
-          params.value.size(-1),
+          params.value.sym_size(-1),
           " instead.");
     }
     return false;
@@ -437,10 +439,10 @@ inline int64_t minimum_gemm_alignment(sdp_params params) {
 }
 
 inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
-  const int64_t query_size_last = params.query.size(-1);
-  const int64_t value_size_last = params.value.size(-1);
+  const auto query_size_last = params.query.sym_size(-1);
+  const auto value_size_last = params.value.sym_size(-1);
   const int64_t alignment = minimum_gemm_alignment(params);
-  if (!(query_size_last == params.key.size(-1) &&
+  if (!(query_size_last == params.key.sym_size(-1) &&
         query_size_last % alignment == 0 && query_size_last > 0 &&
         value_size_last % alignment == 0 && value_size_last > 0)) {
     if (debug) {
@@ -451,9 +453,9 @@ inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
          "Got Query.size(-1): ",
          query_size_last,
          ", Key.size(-1): ",
-         params.key.size(-1),
+         params.key.sym_size(-1),
          ", Value.size(-1): ",
-         params.value.size(-1),
+         params.value.sym_size(-1),
          " instead.");
     }
     return false;
@@ -527,7 +529,7 @@ inline bool check_gpu_sm86_head_dim_128(sdp_params params, bool debug) {
   // on sm86 when head_dim is 128.
   auto dprops = at::cuda::getCurrentDeviceProperties();
   bool is_sm86 = (dprops->major == 8) && (dprops->minor == 6);
-  if (is_sm86 && (params.query.size(-1) == 128)) {
+  if (is_sm86 && (params.query.sym_size(-1) == 128)) {
     if (debug) {
       TORCH_WARN(
           "Memory Efficient Attention does not currently support head_dim == 128 on sm86",
```

test/dynamo/test_dynamic_shapes.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -25,9 +25,7 @@
 test_classes = {}
 
 ALL_DYNAMIC_XFAILS = {
-    "MiscTests": [
-        "test_parsing_sdpa",
-    ],
+    "MiscTests": [],
     "ReproTests": [
         # Could not infer dtype of torch._C.SymIntNode
         "test_convert_boxes_to_pooler_format",
```
