diff --git a/onnxruntime/core/framework/transpose_helper.cc b/onnxruntime/core/framework/transpose_helper.cc index 38f68215a0484..ce713912b7ecf 100644 --- a/onnxruntime/core/framework/transpose_helper.cc +++ b/onnxruntime/core/framework/transpose_helper.cc @@ -22,7 +22,8 @@ struct has_mlas_transpose : std::true_type {}; template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisOutwards( const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop, - int64_t writes_per_writer_per_loop) { + int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr) { + ORT_UNUSED_PARAMETER(tp); const T* end; for (int64_t l = 0; l < num_loops; ++l) { T* output_for_first_writer = output_data; @@ -48,10 +49,10 @@ typename std::enable_if::value, void>::type SimpleTranspo template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisOutwards( const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop, - int64_t writes_per_writer_per_loop) { + int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr) { for (int64_t l = 0; l < num_loops; ++l) { MlasTranspose(input_data, output_data, static_cast(writes_per_writer_per_loop), - static_cast(num_writers)); + static_cast(num_writers), tp); input_data += writes_per_loop; output_data += writes_per_loop; } @@ -82,25 +83,25 @@ void TransposeSingleAxisOutwards(gsl::span permutations, const Ten switch (bytes_per_write) { case (sizeof(uint8_t)): { SimpleTransposeSingleAxisOutwards(input_data, output_data, num_loops, num_writers, writes_per_loop, - writes_per_writer_per_loop); + writes_per_writer_per_loop, tp); break; } case (sizeof(uint16_t)): { SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), reinterpret_cast(output_data), num_loops, num_writers, - writes_per_loop, writes_per_writer_per_loop); + writes_per_loop, writes_per_writer_per_loop, tp); break; } case (sizeof(uint32_t)): { SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), reinterpret_cast(output_data), num_loops, num_writers, - writes_per_loop, writes_per_writer_per_loop); + writes_per_loop, writes_per_writer_per_loop, tp); break; } case (sizeof(uint64_t)): { SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), reinterpret_cast(output_data), num_loops, num_writers, - writes_per_loop, writes_per_writer_per_loop); + writes_per_loop, writes_per_writer_per_loop, tp); break; } default: { @@ -125,7 +126,8 @@ void TransposeSingleAxisOutwards(gsl::span permutations, const Ten template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisInwards( const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop, - int64_t reads_per_reader_per_loop) { + int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr) { + ORT_UNUSED_PARAMETER(tp); T* end; for (int64_t l = 0; l < num_loops; ++l) { const T* input_for_first_reader = input_data; @@ -150,10 +152,10 @@ typename std::enable_if::value, void>::type SimpleTranspo template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisInwards( const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop, - int64_t reads_per_reader_per_loop) { + int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr) { for (int64_t l = 0; l < num_loops; ++l) { MlasTranspose(input_data, output_data, static_cast(num_readers), - static_cast(reads_per_reader_per_loop)); + static_cast(reads_per_reader_per_loop), tp); input_data += reads_per_loop; output_data += reads_per_loop; } @@ -162,7 +164,8 @@ typename std::enable_if::value, void>::type SimpleTranspos // moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits. // `input_shape_override` overrides the shape of `input` for compute purposes. void TransposeSingleAxisInwards(gsl::span permutations, const Tensor& input, Tensor& output, - size_t from, size_t to, const TensorShape* input_shape_override = nullptr) { + size_t from, size_t to, const TensorShape* input_shape_override = nullptr, + concurrency::ThreadPool* tp = nullptr) { ORT_UNUSED_PARAMETER(permutations); const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape(); @@ -184,25 +187,25 @@ void TransposeSingleAxisInwards(gsl::span permutations, const Tens switch (bytes_per_read) { case (sizeof(uint8_t)): { SimpleTransposeSingleAxisInwards(input_data, output_data, num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); + reads_per_reader_per_loop, tp); break; } case (sizeof(uint16_t)): { SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); + reads_per_reader_per_loop, tp); break; } case (sizeof(uint32_t)): { SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); + reads_per_reader_per_loop, tp); break; } case (sizeof(uint64_t)): { SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); + reads_per_reader_per_loop, tp); break; } default: { @@ -236,7 +239,7 @@ void SingleAxisTranspose(gsl::span permutations, const Tensor& inp if (from > to) { TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp); } else { - TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override); + TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override, tp); } } @@ -309,4 +312,4 @@ bool IsTransposeMovingSingleAxis(gsl::span permutations, size_t& f return single_axis_moved; } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 28ae64c4d5b3e..26ba5959d34d8 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1030,49 +1030,15 @@ MlasComputeTanh( // Transpose routines. // +template void MLASCALL MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const int8_t* Input, - int8_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, + const DataType* Input, + DataType* Output, size_t M, - size_t N - ); - -void -MLASCALL -MlasTranspose( - const float* Input, - float* Output, - size_t M, - size_t N + size_t N, + MLAS_THREADPOOL* ThreadPool ); // @@ -1780,20 +1746,22 @@ MlasConvDepthwise( MLAS_HALF_GEMM_POSTPROCESSOR* PostProc ); - inline void MlasTranspose( const MLAS_FP16* Input, MLAS_FP16* Output, size_t M, - size_t N + size_t N, + MLAS_THREADPOOL* ThreadPool ) { MlasTranspose( reinterpret_cast(Input), reinterpret_cast(Output), - M, N); + M, + N, + ThreadPool); } diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp index a758a0e59fb4f..e7efbb5416c79 100644 --- a/onnxruntime/core/mlas/lib/transpose.cpp +++ b/onnxruntime/core/mlas/lib/transpose.cpp @@ -16,6 +16,20 @@ Module Name: #include "mlasi.h" +// +// Define the parameters to execute segments of a transpose operation on worker +// threads. +// + +template +struct MLAS_TRANPOSE_WORK_BLOCK { + ptrdiff_t ThreadCountM; + const ElementType* Input; + ElementType* Output; + size_t M; + size_t N; +}; + #if defined(MLAS_SSE2_INTRINSICS) MLAS_FORCEINLINE @@ -541,51 +555,71 @@ MlasTranspose8xNVector( MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride); } +template void MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ) +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId +); /*++ Routine Description: - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). + This routine is invoked from a worker thread to execute a segment of a transpose Arguments: - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + Context - Supplies the pointer to the context for the threaded operation. - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + ThreadId - Supplies the current index of the threaded operation. Return Value: None. --*/ + +template<> +void +MLASCALL +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId + ) { - size_t n = N; + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; + + // + // Partition the operation along the M dimension. + // + + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); + + // + // Set transpose parameters. + // + + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; + + const uint32_t* Input = WorkBlock->Input + IndexM * N; + uint32_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 4 columns // at a time. // + size_t n = N; + while (n >= 4) { const uint32_t* s = Input; uint32_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ defined(MLAS_LSX_INTRINSICS) @@ -624,7 +658,7 @@ Return Value: const uint32_t* s = Input; uint32_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 4) { @@ -650,68 +684,46 @@ Return Value: } } +template<> void MLASCALL -MlasTranspose( - const float* Input, - float* Output, - size_t M, - size_t N +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId ) { - MlasTranspose( - reinterpret_cast(Input), - reinterpret_cast(Output), - M, - N); -} - + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; -void -MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, - size_t N - ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + // + // Partition the operation along the M dimension. + // - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); -Return Value: + // + // Set transpose parameters. + // - None. + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; ---*/ -{ - size_t n = N; + const uint16_t* Input = WorkBlock->Input + IndexM * N; + uint16_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 4 columns // at a time. // + size_t n = N; + while (n >= 4) { const uint16_t* s = Input; uint16_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) @@ -749,7 +761,7 @@ Return Value: const uint16_t* s = Input; uint16_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 4) { @@ -775,52 +787,47 @@ Return Value: } } - +template<> void MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. +{ + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + // + // Partition the operation along the M dimension. + // - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); -Return Value: + // + // Set transpose parameters. + // - None. + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; ---*/ -{ - size_t n = N; + const uint8_t* Input = WorkBlock->Input + IndexM * N; + uint8_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 8 columns // at a time. // + + size_t n = N; + #if defined(MLAS_TARGET_POWER) while (n >= 16) { const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 16) { MlasTranspose16x16Block(s, N, d, M); @@ -848,7 +855,7 @@ Return Value: const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) @@ -886,7 +893,7 @@ Return Value: const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 8) { @@ -912,17 +919,140 @@ Return Value: } } +template void MLASCALL MlasTranspose( + const DataType* Input, + DataType* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ) +/*++ + +Routine Description: + + This routine transposes the input matrix (M rows by N columns) to the + output matrix (N rows by M columns). + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + M - Supplies the number of rows for the input matrix and the number of + columns for the output matrix. + + N - Supplies the number of columns for the input matrix and the number of + rows for the output matrix. + + ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + +Return Value: + + None. + +--*/ +{ + MLAS_TRANPOSE_WORK_BLOCK WorkBlock; + + // + // Capture the transpose parameters to the work block. + // + + WorkBlock.Input = Input; + WorkBlock.Output = Output; + WorkBlock.M = M; + WorkBlock.N = N; + + // + // Compute the number of target threads given the complexity of the transpose + // operation. Limit the number of threads to the number of rows and try to + // keep each thread processing a minimum number of elements before using + // another thread. + // + + ptrdiff_t ThreadCountM = MlasGetMaximumThreadCount(ThreadPool); + + if (size_t(ThreadCountM) > M) { + ThreadCountM = ptrdiff_t(M); + } + + WorkBlock.ThreadCountM = ThreadCountM; + + MlasExecuteThreaded(MlasTransposeThreaded, &WorkBlock, ThreadCountM, ThreadPool); +} + +template +void +MLASCALL +MlasTranspose( + const uint32_t* Input, + uint32_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template +void +MLASCALL +MlasTranspose( + const uint16_t* Input, + uint16_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template +void +MLASCALL +MlasTranspose( + const uint8_t* Input, + uint8_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template<> +void +MLASCALL +MlasTranspose( const int8_t* Input, int8_t* Output, size_t M, - size_t N) + size_t N, + MLAS_THREADPOOL* ThreadPool + ) { MlasTranspose( reinterpret_cast(Input), reinterpret_cast(Output), M, - N); + N, + ThreadPool); +} + +template<> +void +MLASCALL +MlasTranspose( + const float* Input, + float* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasTranspose( + reinterpret_cast(Input), + reinterpret_cast(Output), + M, + N, + ThreadPool); } diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 37db095e92570..b8203f18b1fbb 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -416,7 +416,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { Xdata, static_cast(transpose_input_buffer.get()), static_cast(C), - static_cast(input_image_size)); + static_cast(input_image_size), + thread_pool); input_data = static_cast(transpose_input_buffer.get()); output_data = static_cast(transpose_output_buffer.get()); add_src = nullptr; @@ -573,7 +574,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { output_data, Ydata, static_cast(output_image_size), - static_cast(M)); + static_cast(M), + thread_pool); if (sum_data != nullptr) { MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation_, sum_data); proc.Process(Ydata, 0, 0, static_cast(M), diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc b/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc index 7c1e05f7ce277..a7bcb3bf9d155 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_pool.cc @@ -156,7 +156,8 @@ Status PoolFp16::Compute(OpKernelContext* context) const { Xdata, static_cast(transpose_input_buffer.get()), static_cast(C), - static_cast(input_image_size)); + static_cast(input_image_size), + thread_pool); input_data = static_cast(transpose_input_buffer.get()); output_data = static_cast(transpose_output_buffer.get()); } @@ -206,7 +207,8 @@ Status PoolFp16::Compute(OpKernelContext* context) const { output_data, Ydata, static_cast(output_image_size), - static_cast(C)); + static_cast(C), + thread_pool); } Xdata += input_image_size * C; Ydata += output_image_size * C; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index f0c1b0b409831..e6a80f0318ad8 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -78,7 +78,7 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Alloca for (int64_t group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) { MlasTranspose(tensor.Data() + (group_id * N * K), ((float*)packed_filter_data) + (group_id * packed_elements_per_group), - K, N); + K, N, nullptr); } bool share_prepacked_weights = (prepacked_weights != nullptr); diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index 7797cbe678bd4..01dd62d9e186d 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -787,7 +787,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const { Xdata, static_cast(transpose_input_buffer.get()), static_cast(C), - static_cast(input_image_size)); + static_cast(input_image_size), + thread_pool); input_data = static_cast(transpose_input_buffer.get()); output_data = static_cast(transpose_output_buffer.get()); } @@ -997,7 +998,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const { output_data, Ydata, static_cast(output_image_size), - static_cast(M)); + static_cast(M), + thread_pool); } Xdata += X_offset; diff --git a/onnxruntime/core/quantization/quantization.h b/onnxruntime/core/quantization/quantization.h index 9acdfa6d86ccf..70e89af5ee653 100644 --- a/onnxruntime/core/quantization/quantization.h +++ b/onnxruntime/core/quantization/quantization.h @@ -195,7 +195,7 @@ inline uint8_t* TransPoseInputData(const uint8_t* input, TensorShape outputshape{static_cast(M), static_cast(N)}; buffer_holder.emplace(DataTypeImpl::GetType(), outputshape, allocator); uint8_t* output = buffer_holder->MutableData(); - MlasTranspose(input, output, M, N); + MlasTranspose(input, output, M, N, nullptr); return output; } diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index f75002f715154..907f944be69e3 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -117,7 +117,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); - MlasTranspose(dequant_buf, transposed, columns, rows); + MlasTranspose(dequant_buf, transposed, columns, rows, threadpool_ptr); uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); diff --git a/onnxruntime/test/mlas/unittest/test_transpose.cpp b/onnxruntime/test/mlas/unittest/test_transpose.cpp index 8fa98411a21ab..725aeef6ea094 100644 --- a/onnxruntime/test/mlas/unittest/test_transpose.cpp +++ b/onnxruntime/test/mlas/unittest/test_transpose.cpp @@ -3,12 +3,13 @@ #include "test_util.h" -template +template class MlasTransposeTest : public MlasTestBase { private: MatrixGuardBuffer BufferInput; MatrixGuardBuffer BufferOutput; MatrixGuardBuffer BufferOutputReference; + MLAS_THREADPOOL* threadpool_; void Test(size_t M, size_t N) { @@ -16,7 +17,7 @@ class MlasTransposeTest : public MlasTestBase { ElementType* Output = BufferOutput.GetBuffer(M * N); ElementType* OutputReference = BufferOutputReference.GetBuffer(M * N); - MlasTranspose(Input, Output, M, N); + MlasTranspose(Input, Output, M, N, threadpool_); ReferenceTranspose(Input, OutputReference, M, N); ASSERT_EQ(memcmp(Output, OutputReference, M * N * sizeof(ElementType)), 0) << " [" << M << "," << N << "]"; @@ -32,10 +33,14 @@ class MlasTransposeTest : public MlasTestBase { public: static const char* GetTestSuiteName() { - static const std::string suite_name = std::string("Transpose_Size") + std::to_string(int(sizeof(ElementType))); + static const std::string suite_name = std::string("Transpose_Type:") + + typeid(ElementType).name() + + std::string(Threaded ? "_Threaded" : "_SingleThread"); return suite_name.c_str(); } + MlasTransposeTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} + void ExecuteShort(void) override { for (size_t m = 1; m <= 32; m++) { for (size_t n = 1; n <= 32; n++) { @@ -48,9 +53,14 @@ class MlasTransposeTest : public MlasTestBase { static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; });