Support Flash Attention 2.5.6 for disc backend #4

Merged: 4 commits, Aug 12, 2024
Changes from all commits

7 changes: 2 additions & 5 deletions test/test_flash_attention_backward.py
@@ -1,4 +1,3 @@
import os
import sys
import unittest

@@ -150,15 +149,13 @@ def test_flash_attn_gqa_backward_fp16(self):
self._backward_internal(torch.float16, n_heads_kv=int(N_HEADS // 2))

def test_flash_attn_gqa_backward_bf16(self):
if not os.environ.get('DISC_DEVICE'):
self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))
self._backward_internal(torch.bfloat16, n_heads_kv=int(N_HEADS // 2))

def test_flash_attn_backward_fp16(self):
self._backward_internal(torch.float16, n_heads_kv=N_HEADS)

def test_flash_attn_backward_bf16(self):
if not os.environ.get('DISC_DEVICE'):
self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)
self._backward_internal(torch.bfloat16, n_heads_kv=N_HEADS)

def test_flash_attn_gqa_backward_fp16_alibi(self):
self._backward_internal(
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/flash_attention_forward.cpp
@@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(int batch_size, int num_heads, int seqlen_q,
xla::PrimitiveType::F32, {batch_size, num_heads, seqlen_q});
xla::Shape out_shape = GetXlaShape(q);
xla::Shape rng_state_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2});
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {2});
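  // S64 rather than U64 presumably because PyTorch exposes no unsigned
  // 64-bit dtype; the backward custom call reinterprets this buffer as
  // uint64_t* before handing it to the flash-attn kernel.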
return xla::ShapeUtil::MakeTupleShape(
{softmax_lse_shape, out_shape, rng_state_shape});
}
219 changes: 148 additions & 71 deletions torch_xla/csrc/runtime/disc/custom_call_flash_attention_backward.cc
@@ -27,7 +27,6 @@ namespace tao {
namespace ral {

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");

struct FlashAttentionBackwardParams {
using index_t = uint32_t;
@@ -57,6 +56,7 @@ struct FlashAttentionBackwardParams {
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;

int total_q;
int total_k;

// The scaling factors for the kernel.
@@ -73,6 +73,12 @@

bool is_bf16;
bool is_causal;
int window_size_left;
int window_size_right;
int alibi_slopes_batch_stride;
bool enable_alibi_slopes;
bool is_seqlens_k_cumulative;
int num_splits;

// Backward specific params
index_t do_batch_stride;
@@ -88,9 +94,11 @@ struct FlashAttentionBackwardParams {
index_t dk_head_stride;
index_t dv_head_stride;

bool deterministic;

void FromString(const std::string& str) {
std::vector<std::string> params_list = absl::StrSplit(str, "|");
TORCH_CHECK(params_list.size() == 43);
TORCH_CHECK(params_list.size() == 51);

// Forward specific param
absl::SimpleAtoi(params_list[0], &this->q_batch_stride);
@@ -102,67 +110,61 @@ struct FlashAttentionBackwardParams {
absl::SimpleAtoi(params_list[6], &this->q_head_stride);
absl::SimpleAtoi(params_list[7], &this->k_head_stride);
absl::SimpleAtoi(params_list[8], &this->v_head_stride);
absl::SimpleAtoi(params_list[9], &this->total_k);
absl::SimpleAtoi(params_list[10], &this->h);
absl::SimpleAtoi(params_list[11], &this->h_k);
absl::SimpleAtoi(params_list[12], &this->h_h_k_ratio);
absl::SimpleAtoi(params_list[13], &this->o_batch_stride);
absl::SimpleAtoi(params_list[14], &this->o_row_stride);
absl::SimpleAtoi(params_list[15], &this->o_head_stride);
absl::SimpleAtoi(params_list[16], &this->b);
absl::SimpleAtoi(params_list[17], &this->seqlen_q);
absl::SimpleAtoi(params_list[18], &this->seqlen_k);
absl::SimpleAtoi(params_list[19], &this->d);
absl::SimpleAtoi(params_list[20], &this->seqlen_q_rounded);
absl::SimpleAtoi(params_list[21], &this->seqlen_k_rounded);
absl::SimpleAtoi(params_list[22], &this->d_rounded);
absl::SimpleAtof(params_list[23], &this->scale_softmax);
absl::SimpleAtof(params_list[24], &this->scale_softmax_log2);
absl::SimpleAtof(params_list[25], &this->p_dropout);
absl::SimpleAtoi(params_list[9], &this->total_q);
absl::SimpleAtoi(params_list[10], &this->total_k);
absl::SimpleAtoi(params_list[11], &this->h);
absl::SimpleAtoi(params_list[12], &this->h_k);
absl::SimpleAtoi(params_list[13], &this->h_h_k_ratio);
absl::SimpleAtoi(params_list[14], &this->o_batch_stride);
absl::SimpleAtoi(params_list[15], &this->o_row_stride);
absl::SimpleAtoi(params_list[16], &this->o_head_stride);
absl::SimpleAtoi(params_list[17], &this->b);
absl::SimpleAtoi(params_list[18], &this->seqlen_q);
absl::SimpleAtoi(params_list[19], &this->seqlen_k);
absl::SimpleAtoi(params_list[20], &this->d);
absl::SimpleAtoi(params_list[21], &this->seqlen_q_rounded);
absl::SimpleAtoi(params_list[22], &this->seqlen_k_rounded);
absl::SimpleAtoi(params_list[23], &this->d_rounded);
absl::SimpleAtof(params_list[24], &this->scale_softmax);
absl::SimpleAtof(params_list[25], &this->scale_softmax_log2);
absl::SimpleAtof(params_list[26], &this->p_dropout);
uint32_t tmp;
absl::SimpleAtoi(params_list[26], &tmp);
absl::SimpleAtoi(params_list[27], &tmp);
this->p_dropout_in_uint8_t = uint8_t(tmp);
absl::SimpleAtof(params_list[27], &this->rp_dropout);
absl::SimpleAtof(params_list[28], &this->scale_softmax_rp_dropout);
absl::SimpleAtob(params_list[29], &this->is_bf16);
absl::SimpleAtob(params_list[30], &this->is_causal);
absl::SimpleAtof(params_list[28], &this->rp_dropout);
absl::SimpleAtof(params_list[29], &this->scale_softmax_rp_dropout);
absl::SimpleAtob(params_list[30], &this->is_bf16);
absl::SimpleAtob(params_list[31], &this->is_causal);
absl::SimpleAtoi(params_list[32], &this->window_size_left);
absl::SimpleAtoi(params_list[33], &this->window_size_right);
absl::SimpleAtoi(params_list[34], &this->alibi_slopes_batch_stride);
absl::SimpleAtob(params_list[35], &this->is_seqlens_k_cumulative);
absl::SimpleAtoi(params_list[36], &this->num_splits);
absl::SimpleAtob(params_list[37], &this->enable_alibi_slopes);
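// window_size_left / window_size_right follow flash-attn's sliding-window
// convention: -1 means unbounded on that side, and causal masking is
// expressed through window_size_right == 0.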

// backward specific params
absl::SimpleAtoi(params_list[31], &this->do_batch_stride);
absl::SimpleAtoi(params_list[32], &this->do_row_stride);
absl::SimpleAtoi(params_list[33], &this->do_head_stride);
absl::SimpleAtoi(params_list[34], &this->dq_batch_stride);
absl::SimpleAtoi(params_list[35], &this->dk_batch_stride);
absl::SimpleAtoi(params_list[36], &this->dv_batch_stride);
absl::SimpleAtoi(params_list[37], &this->dq_row_stride);
absl::SimpleAtoi(params_list[38], &this->dk_row_stride);
absl::SimpleAtoi(params_list[39], &this->dv_row_stride);
absl::SimpleAtoi(params_list[40], &this->dq_head_stride);
absl::SimpleAtoi(params_list[41], &this->dk_head_stride);
absl::SimpleAtoi(params_list[42], &this->dv_head_stride);
const int offset = 38; // FlashAttentionForwardParams has 38 variables
absl::SimpleAtoi(params_list[offset + 0], &this->do_batch_stride);
absl::SimpleAtoi(params_list[offset + 1], &this->do_row_stride);
absl::SimpleAtoi(params_list[offset + 2], &this->do_head_stride);
absl::SimpleAtoi(params_list[offset + 3], &this->dq_batch_stride);
absl::SimpleAtoi(params_list[offset + 4], &this->dk_batch_stride);
absl::SimpleAtoi(params_list[offset + 5], &this->dv_batch_stride);
absl::SimpleAtoi(params_list[offset + 6], &this->dq_row_stride);
absl::SimpleAtoi(params_list[offset + 7], &this->dk_row_stride);
absl::SimpleAtoi(params_list[offset + 8], &this->dv_row_stride);
absl::SimpleAtoi(params_list[offset + 9], &this->dq_head_stride);
absl::SimpleAtoi(params_list[offset + 10], &this->dk_head_stride);
absl::SimpleAtoi(params_list[offset + 11], &this->dv_head_stride);
absl::SimpleAtob(params_list[offset + 12], &this->deterministic);
}
};
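
// For reference, a minimal sketch (not part of this PR) of the producer-side
// convention implied by FromString above: 38 forward fields followed by 13
// backward-specific fields, "|"-joined in exactly the order they are parsed,
// 51 fields in total. The helper name is illustrative only.
inline std::string JoinBackwardAttr(const std::vector<std::string>& fields) {
  TORCH_CHECK(fields.size() == 51);   // must stay in sync with FromString
  return absl::StrJoin(fields, "|");  // absl/strings/str_join.h
}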

void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
const bool configure) {
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d <= 32) {
run_mha_bwd_<elem_type, 32>(params, stream, configure);
} else if (params.d <= 64) {
run_mha_bwd_<elem_type, 64>(params, stream, configure);
} else if (params.d <= 96) {
run_mha_bwd_<elem_type, 96>(params, stream, configure);
} else if (params.d <= 128) {
run_mha_bwd_<elem_type, 128>(params, stream, configure);
} else if (params.d <= 160) {
run_mha_bwd_<elem_type, 160>(params, stream, configure);
} else if (params.d <= 192) {
run_mha_bwd_<elem_type, 192>(params, stream, configure);
} else if (params.d <= 224) {
run_mha_bwd_<elem_type, 224>(params, stream, configure);
} else if (params.d <= 256) {
run_mha_bwd_<elem_type, 256>(params, stream, configure);
}
HEADDIM_SWITCH(params.d,
[&] { run_mha_bwd_<elem_type, kHeadDim>(params, stream); });
});
}
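
// HEADDIM_SWITCH comes from flash-attn's static_switch.h and replaces the
// previous hand-written if/else chain: it binds the runtime head dimension to
// a compile-time kHeadDim template argument. Abbreviated sketch (see the
// flash-attn 2.5.6 sources for the exact macro):
//
//   #define HEADDIM_SWITCH(HEADDIM, ...)            \
//     [&] {                                         \
//       if (HEADDIM <= 32) {                        \
//         constexpr static int kHeadDim = 32;       \
//         return __VA_ARGS__();                     \
//       } else if (HEADDIM <= 64) {                 \
//         constexpr static int kHeadDim = 64;       \
//         return __VA_ARGS__();                     \
//       } /* ... analogous branches up to 256 ... */\
//     }()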

@@ -175,18 +177,21 @@ void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream,
// buffers[5] = softmax_lse
// buffers[6] = cu_seqlens_q
// buffers[7] = cu_seqlens_k
// buffers[8] = dq // this is output
// buffers[9] = dk // this is output
// buffers[10] = dv // this is output
// buffers[11] = softmax_d // this is output
// buffers[8] = rng_state
// buffers[9] = alibi_slopes
// buffers[10] = dq // this is output
// buffers[11] = dk // this is output
// buffers[12] = dv // this is output
// buffers[13] = softmax_d // this is output
template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward(
custom_call_flash_attention_backward_impl(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, void* alibi_slopes_ptr,
void* customAttrs) {
auto attr = getOrParsePDLAttr(ctx, customAttrs,
"custom_call_flash_attention_backward");
@@ -236,7 +241,6 @@ custom_call_flash_attention_backward(
memset(&launch_params, 0, sizeof(launch_params));

launch_params.is_bf16 = params.is_bf16;
launch_params.is_bf16 = true;

// Set the pointers and strides.
launch_params.q_ptr = q.data;
@@ -256,6 +260,9 @@ custom_call_flash_attention_backward(
launch_params.cu_seqlens_q = static_cast<int*>(seqlens_q.data);
launch_params.cu_seqlens_k = static_cast<int*>(seqlens_k.data);

launch_params.alibi_slopes_ptr = alibi_slopes_ptr;
launch_params.alibi_slopes_batch_stride = params.alibi_slopes_batch_stride;

// P = softmax(QK^T)
launch_params.p_ptr = nullptr; // no softmax returned always

@@ -284,6 +291,10 @@ custom_call_flash_attention_backward(
launch_params.scale_softmax_rp_dropout = params.scale_softmax_rp_dropout;

launch_params.is_causal = params.is_causal;
launch_params.window_size_left = params.window_size_left;
launch_params.window_size_right = params.window_size_right;

launch_params.is_seqlens_k_cumulative = true;

launch_params.do_ptr = dout.data;
launch_params.do_row_stride = params.do_row_stride;
@@ -305,10 +316,19 @@ custom_call_flash_attention_backward(
auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
at::Tensor dq_accum;
if (loop) {
dq_accum =
torch::empty({launch_params.b, launch_params.h,
launch_params.seqlen_q_rounded, launch_params.d_rounded},
opts.dtype(at::kFloat));
if (!params.deterministic) {
dq_accum = torch::empty({params.total_q + 128 * launch_params.b,
launch_params.h, launch_params.d_rounded},
opts.dtype(at::kFloat));
} else {
auto dprops = at::cuda::getCurrentDeviceProperties();
const int nsplits = (dprops->multiProcessorCount +
launch_params.b * launch_params.h - 1) /
(launch_params.b * launch_params.h);
dq_accum = torch::zeros({nsplits, params.total_q + 128 * launch_params.b,
launch_params.h, launch_params.d_rounded},
opts.dtype(at::kFloat));
}
}
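  // Note on the two layouts above: without determinism dq is accumulated
  // atomically into a single (total_q + 128 * b, h, d_rounded) fp32 buffer;
  // with determinism each of the nsplits groups of thread blocks appears to
  // get its own slice (hence the extra leading dimension and the zero fill),
  // which fixes the accumulation order. E.g. with 108 SMs, b = 2 and h = 16
  // (illustrative numbers only), nsplits = ceil(108 / 32) = 4.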

at::Tensor dk = torch::from_blob(
@@ -344,6 +364,10 @@ custom_call_flash_attention_backward(
// Softmax sum
launch_params.dsoftmax_sum = dsoftmax.data;

launch_params.deterministic = params.deterministic;
launch_params.dq_accum_split_stride =
!launch_params.deterministic ? 0 : dq_accum.stride(0);

auto launch = &run_mha_bwd;

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -353,11 +377,11 @@ custom_call_flash_attention_backward(
int64_t counter_offset = launch_params.b * launch_params.h * 32;

bool is_dropout = (1.f - launch_params.p_dropout) > 0.0;
if (is_dropout) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
launch_params.philox_args = gen->philox_cuda_state(counter_offset);
}
// TODO(wenting.swt): According to the implementation in
// `flash_attn_varlen_func` of flash-attn v2.5.6, the forward generates
// `rng_state` which is passed as ctx to the backward. Hence, for simplifying
// the logic, the redundant branch where `rng_state` is None has been omitted.
launch_params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data);
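  // The forward pass stores the Philox {seed, offset} pair in this 2-element
  // buffer, so the backward replays exactly the same dropout pattern.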

launch(launch_params, gpu_stream, /*configure=*/false);

@@ -378,12 +402,65 @@ custom_call_flash_attention_backward(
return std::make_tuple(dq_res, dk_res, dv_res, dsoftmax);
}

template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward_noalibi(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, void* customAttrs) {
return custom_call_flash_attention_backward_impl<T_IN, SOFT_MAX_TYPE, M>(
ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
rng_state, nullptr, customAttrs);
}

template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward_alibi_v1(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, MemRefType<float, 1> alibi_slopes,
void* customAttrs) {
return custom_call_flash_attention_backward_impl<T_IN, SOFT_MAX_TYPE, M>(
ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
rng_state, alibi_slopes.data, customAttrs);
}

template <typename T_IN, typename SOFT_MAX_TYPE, int M>
std::tuple<MemRefType<T_IN, M>, MemRefType<T_IN, M>, MemRefType<T_IN, M>,
MemRefType<SOFT_MAX_TYPE, M>>
custom_call_flash_attention_backward_alibi_v2(
ExecutionContext* ctx, void* stream_handle, MemRefType<T_IN, M> dout,
MemRefType<T_IN, M> q, MemRefType<T_IN, M> k, MemRefType<T_IN, M> v,
MemRefType<T_IN, M> out, MemRefType<SOFT_MAX_TYPE, M> softmax_lse,
MemRefType<int32_t, 1> seqlens_q, MemRefType<int32_t, 1> seqlens_k,
MemRefType<int64_t, 1> rng_state, MemRefType<float, 2> alibi_slopes,
void* customAttrs) {
return custom_call_flash_attention_backward_impl<T_IN, SOFT_MAX_TYPE, M>(
ctx, stream_handle, dout, q, k, v, out, softmax_lse, seqlens_q, seqlens_k,
rng_state, alibi_slopes.data, customAttrs);
}
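
// The three wrappers above share one implementation and differ only in the
// ALiBi slopes argument: the noalibi variant passes nullptr, v1 takes a
// rank-1 fp32 memref (one slope per head, shared across the batch) and v2 a
// rank-2 memref (per batch and head), matching flash-attn's convention that
// alibi_slopes is either (num_heads) or (batch_size, num_heads). A minimal,
// hypothetical caller-side sketch of the two shapes:
//
//   auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
//   at::Tensor slopes_v1 = torch::zeros({num_heads}, opts);             // -> alibi_v1
//   at::Tensor slopes_v2 = torch::zeros({batch_size, num_heads}, opts); // -> alibi_v2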

TAO_RAL_API(
"custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward_noalibi<Eigen::half, float, 3>);
TAO_RAL_API(
"custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward_alibi_v1<Eigen::half, float, 3>);
TAO_RAL_API(
"custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward_alibi_v2<Eigen::half, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward<float, float, 3>);
custom_call_flash_attention_backward_noalibi<bfloat16, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward<Eigen::half, float, 3>);
custom_call_flash_attention_backward_alibi_v1<bfloat16, float, 3>);
TAO_RAL_API("custom_call_flash_attention_backward", "gpu",
custom_call_flash_attention_backward<Eigen::bfloat16, float, 3>);
custom_call_flash_attention_backward_alibi_v2<bfloat16, float, 3>);

} // namespace ral
} // namespace tao