
Commit 2f3198c
align code with flash_api.cpp when num_splits > 1
Seventeen17 committed Nov 19, 2024
1 parent ee9724f commit 2f3198c
Showing 2 changed files with 14 additions and 2 deletions.
test/test_flash_attention_forward.py (1 addition, 1 deletion)
@@ -27,7 +27,7 @@ def setup_env():
 @pytest.mark.parametrize("alibi", [False, True])
 @pytest.mark.parametrize("local", [False, True])
 @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("d", [8])
+@pytest.mark.parametrize("d", [256])
 @pytest.mark.parametrize("softmax_scale", [0.25])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
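The tested head dimension is bumped from 8 to 256, the maximum head dimension FlashAttention supports, so the forward test now presumably exercises the configuration where the split-KV scratch buffers are widest. Since stacked pytest.mark.parametrize decorators multiply out into a Cartesian product of cases, the single-value d list swaps the case under test without changing the grid size. A self-contained sketch of that composition behavior (hypothetical test name, not from this repository):

    import pytest

    @pytest.mark.parametrize("causal", [False, True])
    @pytest.mark.parametrize("d", [256])  # single value: swaps the case, not the grid size
    def test_param_grid(causal, d):
        # pytest generates one test per combination: (False, 256) and (True, 256)
        assert d == 256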
torch_xla/csrc/ops/flash_attention_forward.cpp (13 additions, 1 deletion)
@@ -235,13 +235,25 @@ void custom_call_flash_attention_forward(cudaStream_t stream, void** buffers,
   // inference.
   const int num_m_blocks = (params.seqlen_q + 64 - 1) / 64;
   launch_params.num_splits = 0;
-  if (1 - params.p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
+  if (params.p_dropout == 1.0f) {  // SplitKV is not implemented for dropout
     if (launch_params.num_splits < 1) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
       launch_params.num_splits =
           num_splits_heuristic(params.b * params.h * num_m_blocks,
                                dprops->multiProcessorCount, num_n_blocks, 128);
     }
+    if (launch_params.num_splits > 1) {
+      at::Tensor softmax_lse_accum =
+          torch::empty({launch_params.num_splits, params.b, params.h,
+                        launch_params.seqlen_q},
+                       opts.dtype(at::kFloat));
+      at::Tensor out_accum =
+          torch::empty({launch_params.num_splits, params.b, params.h,
+                        launch_params.seqlen_q, launch_params.d_rounded},
+                       opts.dtype(at::kFloat));
+      launch_params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+      launch_params.oaccum_ptr = out_accum.data_ptr();
+    }
   }

   int64_t counter_offset = params.b * params.h * 32;
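For context on the surrounding logic: num_splits_heuristic (inherited from upstream flash_api.cpp) decides how many chunks to split the KV sequence into when batch × heads × M-blocks alone cannot fill the GPU's SMs, while avoiding split counts whose final wave of thread blocks would run mostly empty. A rough Python sketch of that idea, assuming the upstream constants (0.8 fill threshold, 0.85 efficiency cutoff); the authoritative implementation lives in flash_api.cpp:

    import math

    def num_splits_heuristic(batch_nheads_mblocks, num_sms, num_n_blocks, max_splits):
        # If the grid already nearly fills the device, don't split at all.
        if batch_nheads_mblocks >= 0.8 * num_sms:
            return 1
        max_splits = min(max_splits, num_sms, num_n_blocks)

        def efficiency(num_splits):
            # How full the last wave of thread blocks is for this split count.
            n_waves = batch_nheads_mblocks * num_splits / num_sms
            return n_waves / math.ceil(n_waves)

        best = max(efficiency(s) for s in range(1, max_splits + 1))
        # Pick the smallest split count that is nearly as efficient as the best.
        for s in range(1, max_splits + 1):
            if efficiency(s) >= 0.85 * best:
                return s
        return 1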
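The two new scratch tensors mirror what flash_api.cpp allocates for the split-KV path: each of the num_splits partial attention results is written in fp32 to out_accum ([num_splits, b, h, seqlen_q, d_rounded]) together with its per-query log-sum-exp in softmax_lse_accum ([num_splits, b, h, seqlen_q]), and a combine step then reduces over the split dimension with a log-sum-exp rescaling. A minimal PyTorch sketch of that reduction, for illustration only (the actual combine is a CUDA kernel):

    import torch

    def combine_splits(out_accum, lse_accum):
        # lse_accum: [num_splits, b, h, seqlen_q]; out_accum: [num_splits, b, h, seqlen_q, d]
        lse = torch.logsumexp(lse_accum, dim=0)             # global log-sum-exp per query
        scale = torch.exp(lse_accum - lse.unsqueeze(0))     # per-split correction factors
        out = (scale.unsqueeze(-1) * out_accum).sum(dim=0)  # rescale and sum partial outputs
        return out, lse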
