Upgrade XNNPACK to latest version #22012

Merged: 27 commits, Sep 17, 2024

Changes from all commits:
72dce41  upgrade xnn  (Sep 6, 2024)
71411f2  v1.0.183  (Sep 6, 2024)
2c518be  update  (Sep 6, 2024)
baca5c2  merge with main  (Sep 6, 2024)
92eb0cc  update patch  (Sep 8, 2024)
7fcd9a5  add kleidiai for builing mac arm  (Sep 8, 2024)
3c8853d  download kleidiai  (Sep 8, 2024)
c41b40d  one message  (Sep 8, 2024)
679adb1  try kleidiai package  (Sep 8, 2024)
aa58926  update  (Sep 8, 2024)
c31b779  fix lint issue  (Sep 9, 2024)
d6e54b7  kleidiai only for arm64  (Sep 9, 2024)
c46fc52  update cgmanifest  (Sep 9, 2024)
120fc59  update resize op  (Sep 10, 2024)
2090dbb  no cache in iOS  (Sep 10, 2024)
f35c866  update  (Sep 10, 2024)
2936f7c  Revert "no cache in iOS"  (Sep 10, 2024)
1c0f420  rm kleidia test program  (Sep 11, 2024)
93790c8  disable kleidiai test and benchmark  (Sep 14, 2024)
465c89b  add microkernes-prod  (Sep 14, 2024)
57c246c  add ORT_TARGET_PROCESSOR  (Sep 16, 2024)
5b544e4  VS Platform str  (mszhanyi, Sep 16, 2024)
fcedf50  update  (mszhanyi, Sep 16, 2024)
256a236  update as comments  (mszhanyi, Sep 16, 2024)
510b6d0  Update cmake/external/xnnpack.cmake  (mszhanyi, Sep 17, 2024)
cf4cf08  update comments  (mszhanyi, Sep 17, 2024)
037bf82  Merge branch 'zhanyi/upgradexnn' of https://github.com/microsoft/onnx…  (mszhanyi, Sep 17, 2024)
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
@@ -146,7 +146,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "0da379fc4808f9601faef392352018c741c0f297",
+          "commitHash": "309b75c9e56e0a674bf78d59872ce131f814dfb6",
"repositoryUrl": "https://github.com/google/XNNPACK.git"
},
"comments": "googlexnnpack"
4 changes: 3 additions & 1 deletion cmake/deps.txt
@@ -29,7 +29,8 @@ fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349
-googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
+#xnnpack 2024.09.04
+googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59872ce131f814dfb6.zip;39FA5259EAEACE0547284B63D5CEDC4F05553F5A
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
@@ -60,3 +61,4 @@ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/arch
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029
dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99
+kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/v0.2.0/kleidiai-v0.2.0.zip;B1E3173992FD91F20DB904AB77D6E901778C2681
62 changes: 60 additions & 2 deletions cmake/external/xnnpack.cmake
@@ -5,6 +5,8 @@ set(FP16_BUILD_TESTS OFF CACHE INTERNAL "")
set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "")
set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "")
set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "")
+set(KLEIDIAI_BUILD_TESTS OFF CACHE INTERNAL "")
+set(KLEIDIAI_BUILD_BENCHMARK OFF CACHE INTERNAL "")

if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*")
set(XNNPACK_USE_SYSTEM_LIBS OFF)
@@ -30,15 +32,71 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR})
FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool})
onnxruntime_fetchcontent_makeavailable(pthreadpool)

+# --- Determine target processor
+# Why ORT_TARGET_PROCESSOR is only for XNNPACK
+# So far, only Onnxruntime + XNNPack only allow one target processor.
+# And we support Mac universal package, so,
+# CMAKE_OSX_ARCHITECTURES_COUNT greater than 1 is allowed in other places.
+IF(CMAKE_OSX_ARCHITECTURES)
+  LIST(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_COUNT)
+  IF(CMAKE_OSX_ARCHITECTURES_COUNT GREATER 1)
+    MESSAGE(STATUS "Building ONNX Runtime with XNNPACK and multiple OSX architectures is not supported. Got:(${CMAKE_OSX_ARCHITECTURES}). "
+                   "Please specify a single architecture in CMAKE_OSX_ARCHITECTURES and re-configure. ")
+  ENDIF()
+  IF(NOT CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64|arm64e|arm64_32)$")
+    MESSAGE(FATAL_ERROR "Unrecognized CMAKE_OSX_ARCHITECTURES value \"${CMAKE_OSX_ARCHITECTURES}\"")
+  ENDIF()
+  SET(ORT_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}")
+  ADD_COMPILE_OPTIONS("-Wno-shorten-64-to-32")
+ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_GENERATOR_PLATFORM)
+  IF(CMAKE_GENERATOR_PLATFORM MATCHES "^Win32")
+    SET(ORT_TARGET_PROCESSOR "x86")
+  ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^x64")
+    SET(ORT_TARGET_PROCESSOR "x86_64")
+  ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^ARM64")
+    SET(ORT_TARGET_PROCESSOR "arm64")
+  ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^ARM64EC")
+    SET(ORT_TARGET_PROCESSOR "arm64")
+  ELSE()
+    MESSAGE(FATAL_ERROR "Unsupported Visual Studio architecture \"${CMAKE_GENERATOR_PLATFORM}\"")
+  ENDIF()
+ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^i[3-7]86$")
+  SET(ORT_TARGET_PROCESSOR "x86")
+ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64")
+  SET(ORT_TARGET_PROCESSOR "x86_64")
+ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]")
+  SET(ORT_TARGET_PROCESSOR "arm")
+ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
+  SET(ORT_TARGET_PROCESSOR "arm64")
+ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le")
+  SET(ORT_TARGET_PROCESSOR "ppc64")
+ELSEIF(NOT ORT_TARGET_PROCESSOR MATCHES "^(x86(_64)?|arm64|riscv(32|64|128)|Hexagon|ppc64)$")
+  SET(ORT_TARGET_PROCESSOR "${CMAKE_SYSTEM_PROCESSOR}")
+ELSE()
+  MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_PROCESSOR value \"${CMAKE_SYSTEM_PROCESSOR}\"")
+ENDIF()
+MESSAGE(STATUS "Building for ORT_TARGET_PROCESSOR: ${ORT_TARGET_PROCESSOR}")
+
+# KleidiAI is only used in Arm64 platform and not supported by MSVC, the details can be seen in
+# https://github.com/google/XNNPACK/blob/3b3f7b8a6668f6ab3b6ce33b9f1d1fce971549d1/CMakeLists.txt#L206C82-L206C117
+if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC")
+  FetchContent_Declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai})
+  onnxruntime_fetchcontent_makeavailable(kleidiai)
+  set(KLEIDIAI_SOURCE_DIR ${kleidiai_SOURCE_DIR})
+endif()


FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch
)
onnxruntime_fetchcontent_makeavailable(googlexnnpack)
set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR})
set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include)

-set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool)
+set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK microkernels-prod pthreadpool)
+if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC")
+  list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai)
+endif()

# the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up
if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
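For context on how this wiring is consumed: once ONNX Runtime is built with XNNPACK enabled (for example via build.sh --use_xnnpack), the provider is turned on per session through the generic provider-options API. A minimal sketch, assuming the provider name "XNNPACK" and the intra_op_num_threads option documented for the XNNPACK EP; the model path and thread count are placeholders:

// Minimal sketch: enabling the XNNPACK EP that the CMake logic above builds.
// Assumes a build configured with --use_xnnpack; "model.onnx" is a placeholder.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "xnnpack-demo"};
  Ort::SessionOptions so;

  // XNNPACK runs on its own pthreadpool; its size is passed as a provider option.
  so.AppendExecutionProvider("XNNPACK", {{"intra_op_num_threads", "4"}});
  so.SetIntraOpNumThreads(1);  // ORT's own intra-op pool can stay small.

  // Nodes XNNPACK does not claim fall back to the default CPU EP.
  Ort::Session session{env, "model.onnx", so};  // use a wide-string path on Windows
  return 0;
}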
34 changes: 16 additions & 18 deletions cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch
@@ -1,8 +1,8 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dba9b4687..a4345898d 100755
index 1ff85b538..c3ef2183f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -122,7 +122,7 @@ ENDIF()
@@ -253,7 +253,7 @@ ENDIF()
# ---[ Build flags
IF(NOT CMAKE_SYSTEM_NAME)
MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined")
@@ -11,29 +11,27 @@ index dba9b4687..a4345898d 100755
MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"")
ENDIF()
IF(CMAKE_SYSTEM_NAME MATCHES "Windows")
@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY)
TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging)
TARGET_LINK_LIBRARIES(post-operation PRIVATE logging)
TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run)
- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph)
@@ -763,7 +763,12 @@ IF(XNNPACK_BUILD_LIBRARY)
TARGET_LINK_LIBRARIES(operator-run PRIVATE xnnpack-base logging)
TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging)
TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run)
- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph)
+ IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph)
+ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing subgraph)
+ ELSE()
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph)
+ ENDIF()
+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph)
+ ENDIF()
TARGET_LINK_LIBRARIES(XNNPACK PUBLIC xnnpack-base)
SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES)
ENDIF()
IF(NOT MSVC)
@@ -543,8 +548,9 @@ ENDIF()
@@ -772,7 +777,8 @@ IF(NOT MSVC)
ENDIF()
IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm")
SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ")
SET_PROPERTY(SOURCE ${PROD_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ")
- SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ")
- SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ")
+ # set this to armv7-a to workaround build issue. we don't target armv6 so it shouldn't matter
+ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ")
+ SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ")
+ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ")
SET_PROPERTY(SOURCE ${ALL_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ")
SET_PROPERTY(SOURCE ${PROD_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ")
SET_PROPERTY(SOURCE ${ALL_NEONFP16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon-fp16 ")
# GCC requires -mfp16-format=ieee to define __fp16 type, but Clang doesn't support this option at all.
25 changes: 11 additions & 14 deletions onnxruntime/core/providers/xnnpack/math/softmax.cc
@@ -166,24 +166,21 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} {
if (op_type_ == OpComputeType::op_compute_type_qu8) {
// the order of input tensor, x,x_scale, x_zp, y_scale, y_zp
OpQuantParam quant_param = ParseQuantParamForOp(info, x_dtype, 1);
-    xstatus = xnn_create_softmax_nc_qu8(channels,
-                                        channels,
-                                        channels,
-                                        quant_param[0].first[0],  // x_scale
-                                        quant_param[1].second,    // y_zp
-                                        quant_param[1].first[0],  // y_scale
-                                        0,                        // flags,
-                                        &p);
+    xstatus = xnn_create_softmax_nc_qu8(
+        quant_param[0].first[0],  // x_scale, input scale
+        quant_param[1].second,    // y_zp, output zero point
+        quant_param[1].first[0],  // y_scale, output scale
+        0,                        // flags,
+        &p);
} else if (op_type_ == OpComputeType::op_compute_type_fp32) {
-    xstatus = xnn_create_softmax_nc_f32(channels,
-                                        channels,
-                                        channels,
-                                        0,  // flags,
-                                        &p);
+    xstatus = xnn_create_softmax_nc_f32(
+        0,  // flags,
+        &p);
}

ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_softmax_nc_",
OpTypeToString(op_type_), " failed. Status:", xstatus);
+  channel_dim_ = channels;
op0_.reset(p);
}

@@ -205,7 +202,7 @@ Status Softmax::Compute(OpKernelContext* ctx) const {

auto reshape_fn = op_type_ == OpComputeType::op_compute_type_qu8 ? xnn_reshape_softmax_nc_qu8
: xnn_reshape_softmax_nc_f32;
-  status = reshape_fn(op0_.get(), N, threadpool);
+  status = reshape_fn(op0_.get(), channel_dim_, channel_dim_, channel_dim_, N, threadpool);

if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_softmax_nc_", OpTypeToString(op_type_),
1 change: 1 addition & 0 deletions onnxruntime/core/providers/xnnpack/math/softmax.h
@@ -23,6 +23,7 @@ class Softmax final : public XnnpackKernel {
int opset_;
OpComputeType op_type_ = OpComputeType::op_compute_type_invalid;
XnnpackOperator op0_;
+  int64_t channel_dim_;
};
} // namespace xnnpack
} // namespace onnxruntime
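The two softmax changes fit together: the upgraded XNNPACK API takes the channel count and pixel strides at reshape time instead of at operator creation, so the kernel caches channel_dim_ in its constructor and forwards it in Compute. A minimal standalone sketch of that call sequence; only the create and reshape calls are taken directly from the diff, the setup/run signatures are assumed from current XNNPACK headers:

// Sketch of the post-upgrade xnn softmax flow mirrored by the ORT kernel above.
// Shapes are illustrative; a real caller should check every xnn_status.
#include <xnnpack.h>
#include <cstddef>
#include <vector>

int main() {
  if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) return 1;

  // Create: channels/strides are no longer passed here.
  xnn_operator_t op = nullptr;
  if (xnn_create_softmax_nc_f32(/*flags=*/0, &op) != xnn_status_success) return 1;

  const size_t channels = 16, batch = 8;
  std::vector<float> x(batch * channels, 0.5f), y(batch * channels);

  // Reshape: (op, channels, input_stride, output_stride, batch_size, threadpool),
  // matching reshape_fn(op0_.get(), channel_dim_, channel_dim_, channel_dim_, N, threadpool).
  if (xnn_reshape_softmax_nc_f32(op, channels, channels, channels, batch,
                                 /*threadpool=*/nullptr) != xnn_status_success) return 1;

  // Setup and run (signatures assumed, not part of this diff).
  if (xnn_setup_softmax_nc_f32(op, x.data(), y.data()) != xnn_status_success) return 1;
  if (xnn_run_operator(op, /*threadpool=*/nullptr) != xnn_status_success) return 1;

  xnn_delete_operator(op);
  return 0;
}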
8 changes: 3 additions & 5 deletions onnxruntime/core/providers/xnnpack/nn/average_pool.cc
@@ -15,7 +15,6 @@ namespace onnxruntime {
namespace xnnpack {
namespace {
Status CreateXnnpackKernel(const PoolAttributes& pool_attrs,
-                           int64_t C,
const std::optional<std::pair<float, float>>& clip_min_max,
struct xnn_operator*& p,
const OpQuantParam& quant_param,
@@ -42,7 +41,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs,
input_padding_bottom, input_padding_left,
pooling_height, pooling_width,
stride_height, stride_width,
-        C, C, C,  // channels, input_pixel_stride, output_pixel_stride
foutput_min, foutput_max, flags, &p);
} else if (avgpool_type == OpComputeType::op_compute_type_qu8) {
const float output_scale = quant_param[1].first[0];
@@ -53,7 +51,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs,
input_padding_bottom, input_padding_left,
pooling_height, pooling_width,
stride_height, stride_width,
-        C, C, C,  // channels, input_pixel_stride, output_pixel_stride
quant_param[0].second,
quant_param[0].first[0],
quant_param[1].second,
@@ -209,7 +206,7 @@ AveragePool::AveragePool(const OpKernelInfo& info)
ORT_THROW("unsupported AveragePool in XnnpackEP, we have FLOAT|UINT8, but got ", stype);
}
struct xnn_operator* p;
-  auto ret = CreateXnnpackKernel(pool_attrs_, C, clip_min_max_, p,
+  auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p,
quant_param, avgpool_type_);
ORT_ENFORCE(ret.IsOK(), ret.ErrorMessage());
op0_.reset(p);
@@ -222,6 +219,7 @@ Status AveragePool::Compute(OpKernelContext* context) const {
int64_t N = X_shape[0];
int64_t H = X_shape[1];
int64_t W = X_shape[2];
+  int64_t C = X_shape[3];

// set the N dim to the correct value
TensorShapeVector output_dims{output_dims_};
@@ -247,7 +245,7 @@ Status AveragePool::Compute(OpKernelContext* context) const {
? xnn_reshape_average_pooling2d_nhwc_f32
: xnn_reshape_average_pooling2d_nhwc_qu8;

-  auto status = reshape_fn(op0_.get(), N, H, W,
+  auto status = reshape_fn(op0_.get(), N, H, W, C, C, C,
&workspace_size, &workspace_alignment,
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
threadpool);
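The same create-to-reshape move drives the pooling changes: channels, input_pixel_stride, and output_pixel_stride are dropped from xnn_create_average_pooling2d_nhwc_* and supplied to the reshape call instead, which also reports a workspace requirement, so the kernel now reads C from the input shape at Compute time. A rough sketch under those assumptions; the setup call with a workspace pointer is inferred from the reshape outputs and is not shown in the diff:

// Sketch: 2x2 average pooling over an NHWC f32 tensor with the upgraded API.
// Shapes are illustrative; real code should check every xnn_status result
// and honor workspace_alignment when allocating the workspace.
#include <xnnpack.h>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) return 1;

  xnn_operator_t op = nullptr;
  // Channels / pixel strides are no longer create-time arguments.
  if (xnn_create_average_pooling2d_nhwc_f32(
          /*padding top,right,bottom,left=*/0, 0, 0, 0,
          /*pooling h,w=*/2, 2, /*stride h,w=*/2, 2,
          /*output_min=*/-INFINITY, /*output_max=*/+INFINITY,
          /*flags=*/0, &op) != xnn_status_success) return 1;

  const size_t N = 1, H = 8, W = 8, C = 3;
  std::vector<float> x(N * H * W * C, 1.0f), y(N * (H / 2) * (W / 2) * C);

  // Channels and strides move here, and the reshape reports workspace needs.
  size_t workspace_size = 0, workspace_alignment = 0;
  if (xnn_reshape_average_pooling2d_nhwc_f32(
          op, N, H, W, /*channels=*/C, /*input_pixel_stride=*/C,
          /*output_pixel_stride=*/C, &workspace_size, &workspace_alignment,
          /*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
          /*threadpool=*/nullptr) != xnn_status_success) return 1;

  // Setup/run: the workspace parameter is assumed from the reshape outputs.
  std::vector<char> workspace(workspace_size);
  if (xnn_setup_average_pooling2d_nhwc_f32(op, workspace.data(), x.data(),
                                           y.data()) != xnn_status_success) return 1;
  if (xnn_run_operator(op, /*threadpool=*/nullptr) != xnn_status_success) return 1;

  xnn_delete_operator(op);
  return 0;
}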
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/xnnpack/nn/max_pool.cc
@@ -172,7 +172,6 @@ MaxPool::MaxPool(const OpKernelInfo& info)
pooling_height, pooling_width,
stride_height, stride_width,
dilation_height, dilation_width,
-        C, C, C,  // channels, input_pixel_stride, output_pixel_stride
foutput_min, foutput_max, flags, &p);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
maxpool_type_ = OpComputeType::op_compute_type_qu8;
@@ -183,7 +182,6 @@ MaxPool::MaxPool(const OpKernelInfo& info)
pooling_height, pooling_width,
stride_height, stride_width,
dilation_height, dilation_width,
-        C, C, C,  // channels, input_pixel_stride, output_pixel_stride
output_min, output_max, flags, &p);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
maxpool_type_ = OpComputeType::op_compute_type_qs8;
@@ -194,7 +192,6 @@ MaxPool::MaxPool(const OpKernelInfo& info)
pooling_height, pooling_width,
stride_height, stride_width,
dilation_height, dilation_width,
-        C, C, C,  // channels, input_pixel_stride, output_pixel_stride
output_min, output_max, flags, &p);
} else {
auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto()));
@@ -213,6 +210,7 @@ Status MaxPool::Compute(OpKernelContext* context) const {
int64_t N = X_shape[0];
int64_t H = X_shape[1];
int64_t W = X_shape[2];
+  int64_t C = X_shape[3];

// set the N dim to the correct value
TensorShapeVector output_dims{output_dims_};
@@ -234,6 +232,7 @@ }
}

auto status = reshape_fn(op0_.get(), N, H, W,
+                           C, C, C,  // channels, input_pixel_stride, output_pixel_stride
/*output_height_out=*/nullptr, /*output_width_out=*/nullptr,
threadpool);
if (status != xnn_status_success) {
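MaxPool follows the same pattern but without a workspace: the channel count and pixel strides move from the three xnn_create_max_pooling2d_nhwc_* calls to the reshape call, which can also return the computed output height and width. A short sketch, again treating the setup signature as an assumption rather than something shown in this diff:

// Sketch: 3x3, stride-2 max pooling over an NHWC f32 tensor with the upgraded API.
#include <xnnpack.h>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) return 1;

  xnn_operator_t op = nullptr;
  if (xnn_create_max_pooling2d_nhwc_f32(
          /*padding top,right,bottom,left=*/0, 0, 0, 0, /*pooling h,w=*/3, 3,
          /*stride h,w=*/2, 2, /*dilation h,w=*/1, 1,
          /*output_min=*/-INFINITY, /*output_max=*/+INFINITY,
          /*flags=*/0, &op) != xnn_status_success) return 1;

  const size_t N = 1, H = 9, W = 9, C = 4;
  size_t out_h = 0, out_w = 0;
  // Channels / strides now arrive here; the output dims are reported back.
  if (xnn_reshape_max_pooling2d_nhwc_f32(op, N, H, W, C, C, C, &out_h, &out_w,
                                         /*threadpool=*/nullptr) != xnn_status_success) return 1;

  std::vector<float> x(N * H * W * C, 0.0f), y(N * out_h * out_w * C);
  // Setup/run signatures assumed from current XNNPACK headers.
  if (xnn_setup_max_pooling2d_nhwc_f32(op, x.data(), y.data()) != xnn_status_success) return 1;
  if (xnn_run_operator(op, /*threadpool=*/nullptr) != xnn_status_success) return 1;

  xnn_delete_operator(op);
  return 0;
}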