diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 9cbeb161f4c7e..1cc0dd680f0ff 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -502,6 +502,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp ${MLAS_SRC_DIR}/cast_kernel_neon.cpp ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp @@ -517,6 +518,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index 13748b43b1ae6..7531d63fb5fc8 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -202,6 +202,12 @@ Status NchwcConv::Compute(OpKernelContext* context) const { } } +#if defined(__aarch64__) && defined(__linux__) + const bool use_bf16 = use_fastmath_mode_; +#else + const bool use_bf16 = false; +#endif + MlasNchwcConv( X_shape.GetDims().data(), kernel_shape.data(), @@ -216,7 +222,8 @@ Status NchwcConv::Compute(OpKernelContext* context) const { y_data.data(), &activation_, Sum == nullptr, - context->GetOperatorThreadPool()); + context->GetOperatorThreadPool(), + use_bf16); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index 4827d70489674..169eecdeaa02f 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -7,6 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/cpu/nn/pool.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "contrib_ops/cpu/fused_activation.h" namespace onnxruntime { @@ -43,6 +44,10 @@ class NchwcConv final : public OpKernel { public: NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); +#if defined(__aarch64__) && defined(__linux__) + auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); + use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported(); +#endif } Status Compute(OpKernelContext* context) const override; @@ -51,6 +56,9 @@ class NchwcConv final : public OpKernel { ConvAttributes conv_attrs_; MLAS_ACTIVATION activation_; +#if defined(__aarch64__) && defined(__linux__) + bool use_fastmath_mode_{false}; +#endif }; class NchwcPoolBase : public PoolBase { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index be9c997a93ba9..09a0c6de696b4 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1242,7 +1242,8 @@ MlasNchwcConv( float* Output, const MLAS_ACTIVATION* Activation, bool ZeroMode, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + bool UseBf16 = false ); void @@ -1965,6 +1966,7 @@ struct MLAS_SBGEMM_DATA_PARAMS { const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ + bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */ }; /** diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6f6dc7281ff12..e75ca3dc90e60 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -967,6 +967,9 @@ extern "C" { MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; +#if defined(__aarch64__) && defined(__linux__) + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon; +#endif MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon; MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon; MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon; @@ -1372,6 +1375,9 @@ struct MLAS_PLATFORM { MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; +#if defined(__aarch64__) && defined(__linux__) + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseBf16Kernel; +#endif MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; uint32_t NchwcBlockSize; #endif diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index e80438ad99e2c..b913b1c3b8c26 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -575,6 +575,9 @@ Return Value: this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; +#if defined(__aarch64__) && defined(__linux__) + this->ConvPointwiseBf16Kernel = MlasConvPointwiseBf16KernelNeon; +#endif this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon; this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp new file mode 100644 index 0000000000000..1a9949983c3ee --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -0,0 +1,110 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sbconv_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision convolution kernels for ARM NEON. + +--*/ + +#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__) + +#include "mlasi.h" +#include "sconv.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +// +// BF16 Pointwise (1x1) Convolution Kernel using SBGEMM. +// +void MLASCALL +MlasConvPointwiseBf16KernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags +) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + // SBGEMM only adds bias when ZeroMode=true. When accumulating (ZeroMode=false), + // pre-add bias to existing output before the GEMM operations. + if (BiasAddition && AccumulateOutput) { + for (size_t f = 0; f < FilterCount; f++) { + float* output = Output + f * OutputStrideElements; + const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]); + const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]); + const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]); + const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]); + for (size_t i = 0; i < OutputCount; i++) { + MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize]))); + MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4]))); + MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8]))); + MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12]))); + } + } + } + + // Build SBGEMM params for all (filter, input_channel) combinations. + // FilterCount <= 4, InputChannels <= 8, so max 32 elements. + // Bias is set on all elements but SBGEMM only uses it when ZeroMode=true. + MLAS_SBGEMM_DATA_PARAMS gemm_params[32]; + + size_t idx = 0; + for (size_t f = 0; f < FilterCount; f++) { + const float* filter = Filter + f * FilterStrideElements; + float* output = Output + f * OutputStrideElements; + for (size_t ic = 0; ic < InputChannels; ic++, idx++) { + gemm_params[idx].A = Input + ic * InputStrideElements; + gemm_params[idx].B = filter + ic * BlockSize * BlockSize; + gemm_params[idx].C = output; + gemm_params[idx].lda = StrideWidthElements; + gemm_params[idx].ldb = BlockSize; + gemm_params[idx].ldc = BlockSize; + gemm_params[idx].Bias = BiasAddition ? (Bias + f * BlockSize) : nullptr; + gemm_params[idx].AIsfp32 = true; + gemm_params[idx].BIsfp32 = true; + gemm_params[idx].ZeroMode = (ic == 0) && !AccumulateOutput; + gemm_params[idx].OutputProcessor = nullptr; + } + } + + MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr); + + if (ReluActivation) { + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + for (size_t f = 0; f < FilterCount; f++) { + float* output = Output + f * OutputStrideElements; + for (size_t i = 0; i < OutputCount; i++) { + MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector)); + MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector)); + MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector)); + MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector)); + } + } + } +} + +#endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index de7fd72fad45a..5415cb3dc4406 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -112,7 +112,7 @@ MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, template MLAS_FORCEINLINE void -MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode) { constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; size_t PackedStrideN = Strides.N; @@ -131,7 +131,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size // size_t CountK; for (size_t k = 0; k < K; k += CountK) { - bool ZeroMode = (k == 0); + bool ZeroMode = (k == 0) && InitialZeroMode; CountK = std::min(K - k, PackedStrideK); const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; @@ -148,7 +148,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size template void -MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode) { // // Compute the strides to step through slices of the input matrices. @@ -201,7 +201,7 @@ MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_ const float* pbias = ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart - bool ZeroMode = (k == 0); + bool ZeroMode = (k == 0) && InitialZeroMode; MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); } if (PostProcessor != nullptr) { @@ -249,16 +249,17 @@ MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const float* A = (const float*)DataParams->A + RangeStartM * lda; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* bias = DataParams->Bias; + const bool zeroMode = DataParams->ZeroMode; if (!DataParams->BIsfp32) { MlasSBGemmPackedOperation( RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, - lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode ); } else { const size_t ldb = DataParams->ldb; const float* B = (const float*)DataParams->B + RangeStartN; - MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode); } } diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 6f3423a792509..505246841087c 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -53,6 +53,7 @@ struct MLAS_NCHWC_CONV_WORK_BLOCK : MLAS_NCHWC_WORK_BLOCK float* Output; size_t GroupCount; bool ZeroMode; + bool UseBf16; }; // @@ -881,6 +882,11 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; +#if defined(__aarch64__) && defined(__linux__) + if (WorkBlock->UseBf16) { + Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel; + } +#endif #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; #endif @@ -1224,7 +1230,8 @@ MlasNchwcConv( float* Output, const MLAS_ACTIVATION* Activation, bool ZeroMode, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const bool UseBf16 ) /*++ @@ -1269,6 +1276,8 @@ Routine Description: ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. + UseBf16 - Supplies true to use BF16 for convolutions on supported platforms. + Return Value: None. @@ -1288,6 +1297,7 @@ Return Value: WorkBlock.Bias = Bias; WorkBlock.Activation = Activation; WorkBlock.ZeroMode = ZeroMode; + WorkBlock.UseBf16 = UseBf16; // // Capture the generic shape parameters to the work block. diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp index e5a536eb9e4f0..d8b76407edf08 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp +++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp @@ -12,6 +12,14 @@ static size_t Conv2dNchwcRegistLongExecute() { if (GetMlasThreadPool() != nullptr) { count += MlasLongExecuteTests>::RegisterLongExecute(); } +#if defined(__aarch64__) && defined(__linux__) + if (MlasBf16AccelerationSupported()) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } +#endif } return count; @@ -25,6 +33,14 @@ static size_t Conv2dNchwcRegistShortExecute() { if (GetMlasThreadPool() != nullptr) { count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); } +#if defined(__aarch64__) && defined(__linux__) + if (MlasBf16AccelerationSupported()) { + count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); + if (GetMlasThreadPool() != nullptr) { + count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); + } + } +#endif } return count; diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h index c125720668381..c1162c8d150c4 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h +++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h @@ -8,6 +8,8 @@ template class MlasNchwcConv2DTest : public MlasConv2DTest { protected: + bool UseBf16_ = false; + void MlasConv2D( size_t BatchCount, size_t GroupCount, @@ -131,7 +133,8 @@ class MlasNchwcConv2DTest : public MlasConv2DTest { NchwcOutput, &Activation, true, - MlasConv2DTest::threadpool_); + MlasConv2DTest::threadpool_, + UseBf16_); // // Reorder the output buffer. @@ -224,3 +227,51 @@ class MlasNchwcConv2DTest : public MlasConv2DTest { } } }; + +#if defined(__aarch64__) && defined(__linux__) +template +class MlasNchwcConv2DBf16Test : public MlasNchwcConv2DTest { + public: + MlasNchwcConv2DBf16Test() { this->UseBf16_ = true; } + + static const char* GetTestSuiteName() { + static const std::string suite_name(Threaded ? "Conv2dNchwcBf16_Threaded" : "Conv2dNchwcBf16_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong() override { + // BF16 pointwise tests (1x1 kernel, no padding, InputChannels >= BlockSize) + for (unsigned ic : {32u, 64u, 128u}) { + for (unsigned fc : {32u, 64u, 128u}) { + for (unsigned sz : {28u, 14u, 7u}) { + TestBf16(1, 1, ic, sz, sz, fc, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); + TestBf16(1, 1, ic, sz, sz, fc, 1, 1, 0, 0, 0, 0, 1, 1, 2, 2); + TestBf16(4, 1, ic, sz, sz, fc, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); + } + } + } + } + + private: + void TestBf16(size_t B, size_t G, size_t IC, size_t IH, size_t IW, size_t FC, + size_t KH, size_t KW, size_t p0, size_t p1, size_t p2, size_t p3, + size_t DH, size_t DW, size_t SH, size_t SW) { + size_t OH = (IH + p0 + p2 - DH * (KH - 1) - 1) / SH + 1; + size_t OW = (IW + p1 + p3 - DW * (KW - 1) - 1) / SW + 1; + size_t OutputElements = B * G * FC * OH * OW; + + const float* Input = MlasConv2DTest::BufferInput.GetBuffer(B * G * IC * IH * IW); + const float* Filter = MlasConv2DTest::BufferFilter.GetBuffer(G * FC * IC * KH * KW); + const float* Bias = MlasConv2DTest::BufferBias.GetBuffer(G * FC); + float* Output = MlasConv2DTest::BufferOutput.GetBuffer(OutputElements); + float* OutputRef = MlasConv2DTest::BufferOutputReference.GetBuffer(OutputElements); + + this->MlasConv2D(B, G, IC, IH, IW, FC, KH, KW, p0, p1, p2, p3, DH, DW, SH, SW, OH, OW, Input, Filter, Bias, Output); + MlasConv2DTest::ReferenceConv2D(B, G, IC, IH, IW, FC, KH, KW, p0, p1, DH, DW, SH, SW, OH, OW, Input, Filter, Bias, OutputRef); + + for (size_t i = 0; i < OutputElements; i++) { + ASSERT_TRUE(CloseEnough(Output[i], OutputRef[i])) << " @" << i << " got " << Output[i] << " expected " << OutputRef[i]; + } + } +}; +#endif