1 change: 0 additions & 1 deletion docs/OperatorKernels.md
@@ -653,7 +653,6 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -34,7 +34,6 @@ struct AttentionParameters {
float mask_filter_value;
float scale;
bool use_tf32;
bool is_output_bnsh; // whether the output format is BNSH
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
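
Note (not part of the diff): the removed flag selected between the two output layouts named throughout these kernels. BSNH orders a tensor as batch x sequence x num_heads x head_size, while BNSH swaps the middle two axes. A minimal sketch of the corresponding flat offsets, using hypothetical helper names, is:

// Hypothetical index helpers, for orientation only; B/S/N/H follow the
// batch / sequence / num_heads / head_size naming used in these kernels.
inline int64_t OffsetBSNH(int64_t b, int64_t s, int64_t n, int64_t h,
                          int64_t S, int64_t N, int64_t H) {
  return ((b * S + s) * N + n) * H + h;  // batch, sequence, heads, head dim
}
inline int64_t OffsetBNSH(int64_t b, int64_t n, int64_t s, int64_t h,
                          int64_t S, int64_t N, int64_t H) {
  return ((b * N + n) * S + s) * H + h;  // batch, heads, sequence, head dim
}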
21 changes: 9 additions & 12 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -772,23 +772,20 @@ Status UnfusedAttention(

DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);

// compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed
// For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose.
T* temp_output = parameters.is_output_bnsh ? data.output : data.q;
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

if (!parameters.is_output_bnsh) {
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output));
}
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
device_prop.maxThreadsPerBlock, false, temp_output, data.output);
DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return Status::OK();
return result;
}
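
For context on the reverted branch: the strided-batched GEMM above computes, per (batch, head), the softmax probabilities [S x T] times V [T x H_v], leaving the result in BxNxSxH_v, and LaunchTransCtx then rearranges it into the BxSxNxH_v output layout. A rough CPU reference of that combined step (assumed names and float dtype; illustrative only, not the CUDA path):

// probs  ~ scratch2,    shape [B, N, S, T] (softmax output)
// v      ~ data.v,      shape [B, N, T, Hv]
// output ~ data.output, shape [B, S, N, Hv] after the transpose
void AttnProbsTimesVRef(const float* probs, const float* v, float* output,
                        int B, int N, int S, int T, int Hv) {
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < S; ++s)
        for (int h = 0; h < Hv; ++h) {
          float sum = 0.0f;
          for (int t = 0; t < T; ++t)
            sum += probs[((b * N + n) * S + s) * T + t] *
                   v[((b * N + n) * T + t) * Hv + h];
          // The GEMM produces BxNxSxHv; writing with BxSxNxHv indexing here
          // plays the role of LaunchTransCtx.
          output[((b * S + s) * N + n) * Hv + h] = sum;
        }
}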

template <typename T>
@@ -963,7 +960,7 @@ Status QkvToContext(
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.kv_sequence_length;
const int sequence_length = parameters.sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
@@ -984,12 +981,12 @@

if (!parameters.past_present_share_buffer) {
ORT_RETURN_IF_ERROR(ConcatPastToPresent<T>(batch_size, num_heads, qk_head_size, v_head_size,
kv_sequence_length, total_sequence_length,
sequence_length, total_sequence_length,
stream, max_threads_per_block, data));

} else { // past_present_share_buffer
ORT_RETURN_IF_ERROR(PastPresentBufferShare<T>(batch_size, num_heads, qk_head_size, v_head_size,
kv_sequence_length, fused_runner,
sequence_length, fused_runner,
parameters, data, stream, max_threads_per_block));
}

250 changes: 106 additions & 144 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -251,6 +251,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block) {
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -261,96 +262,82 @@
assert(!parameters.is_unidirectional);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));

if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
// 3D inputs in BSNH format (will be transposed)
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
assert(data.mask_index == nullptr);
assert(parameters.hidden_size == parameters.v_hidden_size);

      // For fused cross attention, besides adding bias, K and V need to be packed:
// Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
} else if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
} else {
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
}
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
assert(data.mask_index == nullptr);
assert(parameters.hidden_size == parameters.v_hidden_size);

    // For fused cross attention, besides adding bias, K and V need to be packed:
// Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
} else if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
} else {
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
}

// Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
data.k = nullptr;
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);

data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
assert(data.IsUnfused());
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, data.q,
true, -1);

// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
true, -1);

      // Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
true, -1);

data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
} else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
    // Currently, 4D inputs are only supported in the unfused kernel for Attention-23.
// Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
data.k = nullptr;
data.v = nullptr;

data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
assert(data.IsUnfused());
// Attention-23 does not support bias with Q_K_V_BNSH format.
assert(data.bias == nullptr);
// No need to transpose since QKV is already in BNSH format.
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
} else {
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, data.q,
true, -1);

// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
true, -1);

    // Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
true, -1);

data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}

return Status::OK();
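
Aside on the unfused branch restored above: each LaunchAddBiasTranspose call adds the per-head bias slice and moves the tensor from BxSxNxH to BxNxSxH; the K and V calls offset into the packed bias by num_heads*qk_head_size and 2*num_heads*qk_head_size, consistent with a [Q | K | V] bias layout. A rough CPU equivalent (assumed helper name, float dtype; illustrative only):

// input [B, S, N, H] plus per-head bias [N * H] -> output [B, N, S, H]
void AddBiasTransposeRef(const float* input, const float* bias, float* output,
                         int B, int S, int N, int H) {
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h) {
          float value = input[((b * S + s) * N + n) * H + h];
          if (bias != nullptr) value += bias[n * H + h];  // bias is per (head, dim)
          output[((b * N + n) * S + s) * H + h] = value;
        }
}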
Expand All @@ -373,6 +360,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block) {
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
@@ -385,53 +373,42 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
data.past_key != nullptr && data.past_value != nullptr);
assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

// When there is no past state and there is present state, we output K and V directly to present state.
if (data.past_key == nullptr && data.present_key != nullptr) {
data.k = data.present_key;
data.v = data.present_value;
}
if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
// 3D inputs in BSNH format (will be transposed)
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.use_lean_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
      // Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);

// Key (BxLxNxH) => K (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
// Value (BxLxNxH) => V (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else { // unfused kernel
assert(data.IsUnfused());
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, data.q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
} else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
    // Currently, 4D inputs are only supported in the unfused kernel for Attention-23.
assert(data.IsUnfused());
// No need to transpose since QKV is already in BNSH format.
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.use_lean_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
    // Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
} else {
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);

// Key (BxLxNxH) => K (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
// Value (BxLxNxH) => V (BxNxLxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else { // unfused kernel
assert(data.IsUnfused());
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, data.q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}

return Status::OK();
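
Context: the Q_K_V_BSNH_BNSH_BNSH format chosen in the flash/memory-efficient branch means Q stays in its original BSNH layout (there is no bias to add) while K and V are transposed to BNSH by LaunchTransQkv. A minimal transpose-only reference (assumed helper name, float dtype; not the actual kernel):

// key/value [B, L, N, H] -> k/v [B, N, L, H]; query is left untouched in BSNH.
void TransposeBsnhToBnshRef(const float* input, float* output,
                            int B, int L, int N, int H) {
  for (int b = 0; b < B; ++b)
    for (int l = 0; l < L; ++l)
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          output[((b * N + n) * L + l) * H + h] =
              input[((b * L + l) * N + n) * H + h];
}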
@@ -693,27 +670,14 @@ Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters,
case AttentionQkvFormat::Q_K_V_BSNH:
if (data.past_key != nullptr || data.present_key != nullptr) {
if (data.bias == nullptr) {
DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_NoBias");
DUMP_STRING("PrepareQkv_MHA_WithPast_NoBias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
} else {
DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_Bias");
DUMP_STRING("PrepareQkv_MHA_WithPast_Bias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block));
}
} else { // no past state
DUMP_STRING("PrepareQkv(3D)_MHA_NoPast");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
}
break;
case AttentionQkvFormat::Q_K_V_BNSH:
if (data.past_key != nullptr || data.present_key != nullptr) {
if (data.bias == nullptr) {
DUMP_STRING("PrepareQkv(4D)_MHA_WithPast_NoBias");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
} else {
ORT_THROW("Q_K_V_BNSH format with bias is not supported.");
}
} else { // no past state
DUMP_STRING("PrepareQkv(4D)_MHA_NoPast");
DUMP_STRING("PrepareQkv_MHA_NoPast");
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
}
break;
@@ -744,8 +708,6 @@ bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData<T>&
} else { // no past state
return NoQkvWorkspace_MHA_NoPast(data);
}
case AttentionQkvFormat::Q_K_V_BNSH:
return false; // currently no scenario needs no workspace
default:
ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
}