From 3c559483ef9d87b05a59b6fc70577ff891cd5930 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 24 Sep 2024 03:12:12 +0000 Subject: [PATCH] fix tests --- .../contrib_ops/cpu/bert/attention_utils.cc | 77 -------- .../contrib_ops/cpu/bert/gqa_attention_base.h | 63 ++++--- .../cpu/bert/group_query_attention.cc | 22 +-- .../test/python/transformers/test_gqa_cpu.py | 170 +++++++++--------- 4 files changed, 141 insertions(+), 191 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc index 4b83ec043335e..c8fe9c77d8ff8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc @@ -10,21 +10,6 @@ using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { namespace contrib { -namespace { -template -struct EigenType; - -template <> -struct EigenType { - using Type = float; -}; - -template <> -struct EigenType { - using Type = Eigen::half; -}; -} - // Reshape Q/K/V from BxSxD to BxSxNxH inline Status Reshape_BSD_to_BSNH(Tensor* qkv, int batch_size, @@ -64,42 +49,12 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); - // auto num_elements = per_iter_bh.NumOutputElements(); - - // const auto* input_1 = reinterpret_cast::Type*>(per_iter_bh.EigenInput1().data()); - // ConstEigenVectorArrayMap::Type> input_1_vec_map(input_1, num_elements); - - // auto* output = reinterpret_cast::Type*>(per_iter_bh.OutputEigen().data()); - // EigenVectorArrayMap::Type> output_vec_map(output, num_elements); - - // output_vec_map = input_1_vec_map + static_cast::Type>(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); - // auto num_elements = per_iter_bh.NumOutputElements(); - - // const auto* input_0 = reinterpret_cast::Type*>(per_iter_bh.EigenInput0().data()); - // ConstEigenVectorArrayMap::Type> input_0_vec_map(input_0, num_elements); - - // auto* output = reinterpret_cast::Type*>(per_iter_bh.OutputEigen().data()); - // EigenVectorArrayMap::Type> output_vec_map(output, num_elements); - - // output_vec_map = input_0_vec_map + static_cast::Type>(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); - // auto num_elements = per_iter_bh.NumOutputElements(); - - // const auto* input_0 = reinterpret_cast::Type*>(per_iter_bh.EigenInput0().data()); - // ConstEigenVectorArrayMap::Type> input_0_vec_map(input_0, num_elements); - - // const auto* input_1 = reinterpret_cast::Type*>(per_iter_bh.EigenInput1().data()); - // ConstEigenVectorArrayMap::Type> input_1_vec_map(input_1, num_elements); - - // auto* output = reinterpret_cast::Type*>(per_iter_bh.OutputEigen().data()); - // EigenVectorArrayMap::Type> output_vec_map(output, num_elements); - - // output_vec_map = input_0_vec_map + input_1_vec_map; }}; // For element-wise add // Allocate space for output of Q(BS, D) + bias(D) @@ -159,7 +114,6 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat return Status::OK(); } - // Add bias + reshape for each of Q/K/V // This is used in decoder_with_past when the sequence length is 1 template @@ -175,47 +129,16 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is OpKernelContext* 
context) { // Note: the comments below will refer to Q's dimensions for simplicity auto element_type = DataTypeImpl::GetType(); - //using eigen_type = typename EigenType::Type; constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); - // auto num_elements = per_iter_bh.NumOutputElements(); - - // const auto* input_1 = reinterpret_cast(per_iter_bh.EigenInput1().data()); - // ConstEigenVectorArrayMap input_1_vec_map(input_1, num_elements); - - // auto* output = reinterpret_cast(per_iter_bh.OutputEigen().data()); - // EigenVectorArrayMap output_vec_map(output, num_elements); - - // output_vec_map = input_1_vec_map + static_cast(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); - // auto num_elements = per_iter_bh.NumOutputElements(); - - // const auto* input_0 = reinterpret_cast(per_iter_bh.EigenInput0().data()); - // ConstEigenVectorArrayMap input_0_vec_map(input_0, num_elements); - - // auto* output = reinterpret_cast(per_iter_bh.OutputEigen().data()); - // EigenVectorArrayMap output_vec_map(output, num_elements); - - // output_vec_map = input_0_vec_map + static_cast(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); - // auto num_elements = per_iter_bh.NumOutputElements(); - - // const auto* input_0 = reinterpret_cast(per_iter_bh.EigenInput0().data()); - // ConstEigenVectorArrayMap input_0_vec_map(input_0, num_elements); - - // const auto* input_1 = reinterpret_cast(per_iter_bh.EigenInput1().data()); - // ConstEigenVectorArrayMap input_1_vec_map(input_1, num_elements); - - // auto* output = reinterpret_cast(per_iter_bh.OutputEigen().data()); - // EigenVectorArrayMap output_vec_map(output, num_elements); - - // output_vec_map = input_0_vec_map + input_1_vec_map; }}; // For element-wise add // Get Q's bias from combined bias diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 0099fb616aa28..ccaeb6654e286 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -93,7 +93,8 @@ class GQAAttentionBase { // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, + seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); @@ -132,7 +133,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset((void*)present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_key, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -193,8 +196,8 @@ class GQAAttentionBase { if constexpr (std::is_same::value) { math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, - static_cast(present_buffer_sequence_length), nullptr); + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, + output, static_cast(present_buffer_sequence_length), nullptr); } else { size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float); auto q_k_fp32 = allocator->Alloc(bytes); @@ -254,7 +257,7 @@ class GQAAttentionBase { template void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const float* attention_probs, // Attention probs with size BxNxSxT + const float* attention_probs, // Attention probs with size BxNxSxT const T* V, // V value with size BxN_kvxSxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor const size_t batch_size, // batch size @@ -279,7 +282,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset((void*)present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_value, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -303,6 +308,13 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; + size_t output_fp32_bytes = 0; + if constexpr (std::is_same::value) { + output_fp32_bytes = SafeInt(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float); + } + auto output_fp32 = allocator->Alloc(output_fp32_bytes); + BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator)); + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const size_t batch_index = i / num_heads_; @@ -323,32 +335,39 @@ class GQAAttentionBase { i / kv_num_heads_factor); } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; if constexpr (std::is_same::value) { - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ - attention_probs + attention_probs_offset, - static_cast(present_buffer_sequence_length), v, 
static_cast(head_size), - 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, + static_cast(head_size), 0.0f /*beta*/, output_current, + static_cast(hidden_size), nullptr); } else { - size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float); - auto v_o_fp32 = allocator->Alloc(bytes); - BufferUniquePtr scratch_buffer(v_o_fp32, BufferDeleter(allocator)); + size_t bytes = head_size * total_seqlen * sizeof(float); + auto v_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator)); - float* v_fp32_ptr = static_cast(v_o_fp32); + float* v_fp32_ptr = static_cast(v_fp32); MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen); - float* output_fp32 = v_fp32_ptr + head_size * total_seqlen; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ - attention_probs + attention_probs_offset, - static_cast(present_buffer_sequence_length), v_fp32_ptr, static_cast(head_size), - 0.0f /*beta*/, output_fp32, static_cast(hidden_size), nullptr); - - MlasConvertFloatToHalfBuffer(output_fp32, output_current, head_size * sequence_length); + float* output_fp32_current = static_cast(output_fp32) + + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v_fp32_ptr, + static_cast(head_size), 0.0f /*beta*/, output_fp32_current, + static_cast(hidden_size), nullptr); } } }); + + if constexpr (std::is_same::value) { + MlasConvertFloatToHalfBuffer(static_cast(output_fp32), + output, + SafeInt(sequence_length) * batch_size * num_heads_ * head_size); + } } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index d3218e7d6602b..a1ed35e54b008 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -22,17 +22,17 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx -#define REGISTER_KERNEL_TYPED(T) \ -ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ - GroupQueryAttention); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + GroupQueryAttention); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index dc21d4e4a5890..08ec5de328b9d 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -29,6 +29,12 @@ GREEN = "\033[32m" RESET = "\033[0m" +ORT_TYPE = 
TensorProto.FLOAT +TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32 +NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32 +RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3 +ATOL = RTOL + class Formats: BSNH = 0 @@ -186,7 +192,7 @@ def create_group_query_attention_graph_prompt( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.q_sequence_length, @@ -212,7 +218,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -221,7 +227,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -233,7 +239,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -243,7 +249,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -256,7 +262,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -264,7 +270,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -275,12 +281,12 @@ def create_group_query_attention_graph_prompt( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -290,7 +296,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -300,7 +306,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -310,7 +316,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -378,7 +384,7 @@ def create_group_query_attention_graph_past( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -391,7 +397,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else 
config.kv_num_heads, @@ -401,7 +407,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -424,7 +430,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -433,7 +439,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -445,7 +451,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -453,7 +459,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -464,12 +470,12 @@ def create_group_query_attention_graph_past( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -479,7 +485,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -641,7 +647,7 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True): config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) key_padding_mask = generate_random_padding_mask( @@ -722,13 +728,13 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -835,13 +841,13 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -1017,9 +1023,11 @@ def attention_ref( attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return 
output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -1058,8 +1066,8 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1067,7 +1075,7 @@ def parity_check_gqa_prompt( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1076,7 +1084,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1085,7 +1093,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1094,7 +1102,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1103,7 +1111,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1129,8 +1137,8 @@ def parity_check_gqa_prompt( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1152,8 +1160,8 @@ def parity_check_gqa_prompt( kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...").to(dtype=TORCH_TYPE) k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded @@ -1218,11 +1226,11 @@ def parity_check_gqa_prompt( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1271,8 +1279,8 @@ def parity_check_gqa_prompt_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1280,7 +1288,7 @@ def parity_check_gqa_prompt_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1289,7 +1297,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1298,7 +1306,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1321,8 +1329,8 @@ def parity_check_gqa_prompt_no_buff( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1405,11 +1413,11 @@ def parity_check_gqa_prompt_no_buff( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "No buff", @@ -1458,8 +1466,8 @@ def parity_check_gqa_past( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ 
-1467,7 +1475,7 @@ def parity_check_gqa_past( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1476,7 +1484,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1485,7 +1493,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1494,7 +1502,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1503,7 +1511,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1534,8 +1542,8 @@ def parity_check_gqa_past( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1624,11 +1632,11 @@ def parity_check_gqa_past( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1677,8 +1685,8 @@ def parity_check_gqa_past_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): torch.manual_seed(69) q = torch.randn( @@ -1687,7 +1695,7 @@ def parity_check_gqa_past_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1696,7 +1704,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1705,7 +1713,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1714,7 +1722,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = 
torch.randn( @@ -1723,7 +1731,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1759,8 +1767,8 @@ def parity_check_gqa_past_no_buff( angle = ( torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi ) - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1849,7 +1857,7 @@ def parity_check_gqa_past_no_buff( out = out.detach().cpu().numpy() # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "NO buff", @@ -1983,8 +1991,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -1996,8 +2004,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2042,8 +2050,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2053,8 +2061,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed,
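
Note on the gqa_attention_base.h change: the fp16 path previously converted each head's probs x V GEMM result from float back to half inside the parallel loop, writing through a per-head fp32 scratch that was sized for a contiguous head_size-wide block but addressed with the full hidden_size stride. The patch instead allocates one fp32 output buffer for the whole B x S x N x H result before ThreadPool::TryParallelFor, lets every (batch, head) GEMM write its H-wide slice of that buffer, and performs a single MlasConvertFloatToHalfBuffer after the loop. Below is a rough NumPy sketch of that accumulation pattern only; the function name and shapes are illustrative, not taken from the ONNX Runtime sources, and num_heads == kv_num_heads is assumed for brevity.

import numpy as np

def vx_attention_fp16(attention_probs, v):
    # attention_probs: fp32, shape (B, N, S, T) -- already computed in fp32.
    # v: fp16, shape (B, N, T, H).
    batch_size, num_heads, seq_len, _total_seqlen = attention_probs.shape
    head_size = v.shape[-1]
    hidden_size = num_heads * head_size
    # One shared fp32 scratch for the whole output, laid out B x S x (N*H),
    # mirroring the output_fp32 buffer allocated before the parallel loop.
    output_fp32 = np.zeros((batch_size, seq_len, hidden_size), dtype=np.float32)
    for b in range(batch_size):
        for n in range(num_heads):
            v_fp32 = v[b, n].astype(np.float32)  # per-head half -> float conversion of V
            # Each head's GEMM writes its H-wide slice, strided by hidden_size.
            output_fp32[b, :, n * head_size:(n + 1) * head_size] = attention_probs[b, n] @ v_fp32
    # Single float -> half conversion after all heads finish, instead of one per head.
    return output_fp32.astype(np.float16)

probs = np.random.rand(1, 4, 8, 8).astype(np.float32)
v = np.random.rand(1, 4, 8, 16).astype(np.float16)
out = vx_attention_fp16(probs, v)  # shape (1, 8, 64), dtype float16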
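
Note on the test_gqa_cpu.py change: the hard-coded TensorProto.FLOAT, torch.float32/numpy.float32 dtypes, and 1e-3 tolerances are replaced by module-level ORT_TYPE / TORCH_TYPE / NUMPY_TYPE / RTOL / ATOL constants, so the same parity checks can be re-run in fp16 by flipping ORT_TYPE (3e-2 is used for fp16 because its machine epsilon of roughly 1e-3 compounds over the long attention reductions). The snippet below restates those constants and adds a small comparison helper for illustration; outputs_match is a hypothetical name, not a function in the test file.

import numpy
import torch
from onnx import TensorProto

# Same constants the patch adds at the top of test_gqa_cpu.py: flipping
# ORT_TYPE to TensorProto.FLOAT16 switches every tensor dtype and loosens
# the comparison tolerances from 1e-3 to 3e-2.
ORT_TYPE = TensorProto.FLOAT
TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32
NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32
RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3
ATOL = RTOL

def outputs_match(out, out_ref):
    # Hypothetical helper: compares the ORT output against the torch reference
    # the same way the parity_check_gqa_* functions do.
    return numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)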