From 094616bcce2de04a8e390737af5406b593330243 Mon Sep 17 00:00:00 2001 From: "Huang, Zhiwei" Date: Thu, 6 Apr 2023 11:26:10 +0800 Subject: [PATCH] [Intel GPU]Add 13 Intel GPU kernels (#479) --- backends/intel_gpu/CMakeLists.txt | 24 +- backends/intel_gpu/README.md | 8 +- backends/intel_gpu/kernels/argsort_kernel.cc | 255 +++++++ .../intel_gpu/kernels/assign_value_kernel.cc | 97 +++ backends/intel_gpu/kernels/cast_kernel.cc | 127 ++++ backends/intel_gpu/kernels/compare_kernel.cc | 332 +++++++++ backends/intel_gpu/kernels/dnn_support.hpp | 190 ++++- .../intel_gpu/kernels/elementwise_kernel.cc | 159 +++++ backends/intel_gpu/kernels/full_kernel.cc | 55 ++ backends/intel_gpu/kernels/kernels.h | 30 + backends/intel_gpu/kernels/memcpy_kernel.cc | 104 +++ backends/intel_gpu/kernels/phi_funcs.h | 391 +++++++++++ backends/intel_gpu/kernels/reduce_kernel.cc | 231 ++++++ backends/intel_gpu/kernels/reshape_kernel.cc | 189 +++++ backends/intel_gpu/kernels/slice_kernel.cc | 153 ++++ backends/intel_gpu/kernels/softmax_kernel.cc | 228 ++++++ .../intel_gpu/kernels/transpose_kernel.cc | 97 +++ .../kernels/uniform_random_kernel.cc | 113 +++ backends/intel_gpu/load.sh | 2 +- backends/intel_gpu/runtime/runtime.cc | 20 +- backends/intel_gpu/tests/CMakeLists.txt | 16 +- backends/intel_gpu/tests/test_MNIST_model.py | 91 +++ .../intel_gpu/tests/unittests/CMakeLists.txt | 23 + .../tests/unittests/test_argsort_op.py | 439 ++++++++++++ .../tests/unittests/test_assign_value_op.py | 110 +++ .../intel_gpu/tests/unittests/test_cast_op.py | 132 ++++ .../tests/unittests/test_compare_op.py | 308 ++++++++ .../unittests/test_elementwise_mul_op.py | 240 +++++++ .../tests/unittests/test_fill_constant_op.py | 241 +++++++ .../intel_gpu/tests/unittests/test_mean_op.py | 96 +++ .../tests/unittests/test_memcpy_op.py | 117 ++++ .../tests/unittests/test_reduce_op.py | 439 ++++++++++++ .../tests/unittests/test_reshape_op.py | 352 ++++++++++ .../tests/unittests/test_slice_op.py | 655 ++++++++++++++++++ .../tests/unittests/test_softmax_op.py | 167 +++++ .../tests/unittests/test_transpose_op.py | 406 +++++++++++ .../tests/unittests/test_uniform_random_op.py | 571 +++++++++++++++ 37 files changed, 7165 insertions(+), 43 deletions(-) create mode 100644 backends/intel_gpu/kernels/argsort_kernel.cc create mode 100644 backends/intel_gpu/kernels/assign_value_kernel.cc create mode 100644 backends/intel_gpu/kernels/cast_kernel.cc create mode 100644 backends/intel_gpu/kernels/compare_kernel.cc create mode 100644 backends/intel_gpu/kernels/elementwise_kernel.cc create mode 100644 backends/intel_gpu/kernels/full_kernel.cc create mode 100644 backends/intel_gpu/kernels/kernels.h create mode 100644 backends/intel_gpu/kernels/memcpy_kernel.cc create mode 100644 backends/intel_gpu/kernels/phi_funcs.h create mode 100644 backends/intel_gpu/kernels/reduce_kernel.cc create mode 100644 backends/intel_gpu/kernels/reshape_kernel.cc create mode 100644 backends/intel_gpu/kernels/slice_kernel.cc create mode 100644 backends/intel_gpu/kernels/softmax_kernel.cc create mode 100644 backends/intel_gpu/kernels/transpose_kernel.cc create mode 100644 backends/intel_gpu/kernels/uniform_random_kernel.cc create mode 100644 backends/intel_gpu/tests/test_MNIST_model.py create mode 100644 backends/intel_gpu/tests/unittests/CMakeLists.txt create mode 100644 backends/intel_gpu/tests/unittests/test_argsort_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_assign_value_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_cast_op.py create mode 100755 
backends/intel_gpu/tests/unittests/test_compare_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_elementwise_mul_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_fill_constant_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_mean_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_memcpy_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_reduce_op.py create mode 100755 backends/intel_gpu/tests/unittests/test_reshape_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_slice_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_softmax_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_transpose_op.py create mode 100644 backends/intel_gpu/tests/unittests/test_uniform_random_op.py diff --git a/backends/intel_gpu/CMakeLists.txt b/backends/intel_gpu/CMakeLists.txt index 89ddd32f8..696d69a52 100644 --- a/backends/intel_gpu/CMakeLists.txt +++ b/backends/intel_gpu/CMakeLists.txt @@ -30,21 +30,16 @@ endif() find_package(PythonInterp 3.9 REQUIRED) find_package(PythonLibs 3.9 REQUIRED) -if(NOT DEFINED ENV{DPL_ROOT}) - message(FATAL_ERROR "Please set the env var DPL_ROOT!" $ENV{DPL_ROOT}) -endif() - set(DPCPP_COMPIER_PATH "${ONEAPI_PATH}/${ONEAPI_COMPILER_DIR}/linux/bin/dpcpp") set(ONEAPI_SYCL_INCLUDE "${ONEAPI_PATH}/${ONEAPI_COMPILER_DIR}/linux/include/sycl/") set(ONEAPI_SYCL_LIBDIR "${ONEAPI_PATH}/${ONEAPI_COMPILER_DIR}/linux/lib/") set(CMAKE_CXX_COMPILER "${DPCPP_COMPIER_PATH}") -# TODO(Zhiwei35): when using oneDNN, should add oneDNN include dir. if(WITH_ONEDNN) include_directories(${ONEAPI_SYCL_INCLUDE} ${ONEDNN_INC}) else() - include_directories(${ONEAPI_SYCL_INCLUDE} $ENV{DPL_ROOT}/include) + include_directories(${ONEAPI_SYCL_INCLUDE} $ENV{DNNLROOT}/include) endif() set(PLUGIN_NAME "paddle-custom-intel-gpu") @@ -65,19 +60,18 @@ endif() add_definitions(-std=c++14) +# custom kernels +file( + GLOB_RECURSE PLUGIN_SRCS + RELATIVE ${CMAKE_SOURCE_DIR} + kernels/*.cc) +message(STATUS "PLUGIN_SRCS : ${PLUGIN_SRCS}") + # custom runtime -set(PLUGIN_SRCS runtime/runtime.cc) +list(APPEND PLUGIN_SRCS runtime/runtime.cc) add_definitions(-DPADDLE_WITH_CUSTOM_DEVICE) add_definitions(-DPADDLE_WITH_CUSTOM_KERNEL) -# custom kernels -if(WITH_KERNELS) - # TODO(Zhiwei35, to be more general when adding many kernels) - list(APPEND PLUGIN_SRCS kernels/mean_kernel.cc) - message(STATUS "CUSTOM_KERNEL_SRCS : ${CUSTOM_KERNEL_SRCS}") - message(STATUS "PLUGIN_SRCS : ${PLUGIN_SRCS}") -endif() - # build shared library add_library(${PLUGIN_NAME} SHARED ${PLUGIN_SRCS}) diff --git a/backends/intel_gpu/README.md b/backends/intel_gpu/README.md index 16fd0ce05..3540ebb5e 100644 --- a/backends/intel_gpu/README.md +++ b/backends/intel_gpu/README.md @@ -33,7 +33,7 @@ pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu mkdir build && cd build cmake .. 
-make -j8 +make -j $(nproc) # using pip to install the output pip install dist/paddle_custom_intel_gpu*.whl @@ -42,10 +42,10 @@ pip install dist/paddle_custom_intel_gpu*.whl ## Verification ```bash -# list available hardware backends -python -c "import paddle; print(paddle.device.get_all_custom_device_type())" +# check the plugin status +python -c "import paddle; print('intel_gpu' in paddle.device.get_all_custom_device_type())" # expected output -['intel_gpu'] +True ``` diff --git a/backends/intel_gpu/kernels/argsort_kernel.cc b/backends/intel_gpu/kernels/argsort_kernel.cc new file mode 100644 index 000000000..e169ee124 --- /dev/null +++ b/backends/intel_gpu/kernels/argsort_kernel.cc @@ -0,0 +1,255 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/kernels.h" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +namespace gpu { + +template +void Transpose(const phi::Context& ctx, + const phi::DenseTensor& x, + const std::vector& axis, + T* out_data, + const std::vector& out_dims, + int64_t out_numel) { + auto x_dims = x.dims(); + auto x_data = x.data(); + show_kernel("TransposeKernel"); + show_debug("x{dims}=" << x.dims() << " x{rank}=" << x_dims.size() + << " out{dims}=" << out_dims); + + if (out_numel == 0) { + return; + } + auto rank = x_dims.size(); + if (rank == 1) { + memcpy(out_data, x_data, x.numel() * sizeof(T)); + } + PD_CHECK(axis.size() == rank, + "axis.size (%d) must be equal the rank of input (%d).", + axis.size(), + rank); + + std::vector step(out_dims.size(), 1); + for (auto i = out_dims.size() - 1; i > 0; --i) { + step[i - 1] = step[i] * out_dims[i]; + } + + std::vector index(rank, 0); + for (auto i = 0; i < x.numel(); ++i) { + std::vector dst_index(rank, 0); + for (auto j = 0; j < rank; ++j) { + dst_index[j] = index[axis[j]]; + } + out_data[phi::vec_product(dst_index, step)] = x_data[i]; + + index.back()++; + for (auto j = rank - 1; j > 0; --j) { + if (index[j] >= x_dims[j]) { + index[j - 1]++; + index[j] = 0; + } else { + break; + } + } + } +} + +template +void FullSort(int input_height, + int input_width, + int input_dim, + T* input, + T* t_out, + int64_t* t_indices, + bool descending) { + for (int i = 0; i < input_height; ++i) { + std::vector> col_vec; + col_vec.reserve(input_width); + if (input_dim == 1) { + for (int j = 0; j < input_width; ++j) { + col_vec.push_back(std::pair(input[j], j)); + } + } else { + for (int j = 0; j < input_width; ++j) { + col_vec.push_back(std::pair(input[i * input_width + j], j)); + } + } + std::sort(col_vec.begin(), + col_vec.end(), + [&](const std::pair& l, const std::pair& r) { + if (descending) + // TODO(Zhiwei35) comparison with NaN always evaluates to + // false in fast floating point modes and need to enhance + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + else + return 
(!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); + }); + for (int j = 0; j < input_width; ++j) { + t_out[i * input_width + j] = col_vec[j].first; + t_indices[i * input_width + j] = col_vec[j].second; + } + } +} + +template +void ArgsortKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& input, + int axis, + bool descending, + phi::DenseTensor* output, + phi::DenseTensor* indices) { + auto in_dims = input.dims(); + auto out_dims = output->dims(); + auto out_size = output->numel(); + auto ids_size = indices->numel(); + auto out_mem_size = out_size * sizeof(T); + auto ids_mem_size = ids_size * sizeof(int64_t); + axis = (axis < 0) ? (in_dims.size() + axis) : axis; + T* out_data = dev_ctx.template Alloc(output); + int64_t* ids_data = dev_ctx.template Alloc(indices); + + show_kernel("argsort in_dims=" << in_dims << " axis=" << axis << " type=" + << dnn_support::type2String::name() + << " desc=" << descending); + // TODO(Zhiwei35): support argsort with dims >=3 + PD_CHECK(in_dims.size() < 3, "PoC Lenet/Mnist use case only"); + auto* q = static_cast(const_cast(dev_ctx.stream())); + size_t n = 1; + size_t m = in_dims[0]; + + if (in_dims.size() == 2) { + n = in_dims[0]; + m = in_dims[1]; + } + phi::DenseTensor cpu_input; + cpu_input.Resize(std::vector(in_dims)); + cpu_input.set_dtype(input.dtype()); + auto cpu_input_data = dev_ctx.template HostAlloc(&cpu_input); + + auto input_data = input.data(); + q->memcpy(cpu_input_data, input_data, input.memory_size()); + q->wait(); + // cpu implement + phi::DenseTensor cpu_output; + cpu_output.Resize(std::vector(out_dims)); + cpu_output.set_dtype(output->dtype()); + auto cpu_output_dims = cpu_output.dims(); + auto cpu_output_numel = cpu_output.numel(); + auto cpu_output_data = dev_ctx.template HostAlloc(&cpu_output); + + phi::DenseTensor cpu_ids; + cpu_ids.Resize(std::vector(indices->dims())); + cpu_ids.set_dtype(indices->dtype()); + auto cpu_ids_dims = cpu_ids.dims(); + auto cpu_ids_numel = cpu_ids.numel(); + auto cpu_ids_data = dev_ctx.template HostAlloc(&cpu_ids); + // no need transpose + if (axis == -1 || axis + 1 == in_dims.size()) { + const int input_height = n; + const int input_width = m; + FullSort(input_height, + input_width, + in_dims.size(), + cpu_input_data, + cpu_output_data, + cpu_ids_data, + descending); + } else { + // do cpu transpose + std::vector trans; + for (int i = 0; i < axis; i++) { + trans.push_back(i); + } + trans.push_back(in_dims.size() - 1); + for (int i = axis + 1; i < in_dims.size() - 1; i++) { + trans.push_back(i); + } + trans.push_back(axis); + std::vector trans_dims(in_dims.cbegin(), in_dims.cend()); + for (size_t i = 0; i < trans.size(); i++) { + trans_dims[i] = in_dims[trans[i]]; + } + + phi::DenseTensor trans_inp; + trans_inp.Resize(trans_dims); + auto trans_input_dims = trans_inp.dims(); + auto trans_input_numel = trans_inp.numel(); + auto trans_input_data = dev_ctx.template HostAlloc(&trans_inp); + // do cpu transpose input + Transpose(dev_ctx, + cpu_input, + trans, + trans_input_data, + trans_input_dims, + trans_input_numel); + + const int64_t input_height = trans_dims[0]; + const int64_t input_width = trans_dims[trans_dims.size() - 1]; + + phi::DenseTensor cpu_tmp_output; + cpu_tmp_output.Resize(trans_dims); + cpu_tmp_output.set_dtype(output->dtype()); + auto cpu_tmp_output_data = dev_ctx.template HostAlloc(&cpu_tmp_output); + + phi::DenseTensor cpu_tmp_ids; + cpu_tmp_ids.Resize(trans_dims); + cpu_tmp_ids.set_dtype(indices->dtype()); + auto cpu_tmp_ids_data = 
dev_ctx.template HostAlloc(&cpu_tmp_ids); + + FullSort(input_height, + input_width, + trans_dims.size(), + trans_input_data, + cpu_tmp_output_data, + cpu_tmp_ids_data, + descending); + + Transpose( + dev_ctx, cpu_tmp_ids, trans, cpu_ids_data, cpu_ids_dims, cpu_ids_numel); + // CPU transpose back + Transpose(dev_ctx, + cpu_tmp_output, + trans, + cpu_output_data, + cpu_output_dims, + cpu_output_numel); + } + // copy cpu result to intel gpu + q->memcpy(out_data, cpu_output_data, out_mem_size); + q->memcpy(ids_data, cpu_ids_data, ids_mem_size); + q->wait(); +} // ArgsortKernel + +} // namespace gpu + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(argsort, + intel_gpu, + ALL_LAYOUT, + custom_kernel::gpu::ArgsortKernel, + float, + double, + int, + int64_t) {} diff --git a/backends/intel_gpu/kernels/assign_value_kernel.cc b/backends/intel_gpu/kernels/assign_value_kernel.cc new file mode 100644 index 000000000..ffab9f8dd --- /dev/null +++ b/backends/intel_gpu/kernels/assign_value_kernel.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void AssignValueKernel(const phi::Context& dev_ctx, + const std::vector& shape, + phi::DataType dtype, + const std::vector& values, + phi::DenseTensor* out) { + show_kernel("AssignValue-SYCL, type=" << dnn_support::type2String::name()); + + auto template_dtype = phi::capi::CppTypeToPDType::Type(); + PD_CHECK(dtype == template_dtype, + "Argument dtype mismatch for kernel dtype, " + "argument dtype is %s, kernel dtype is %s.", + dtype, + template_dtype); + auto out_size = values.size(); + out->Resize({static_cast(out_size)}); + auto out_data = dev_ctx.template Alloc(out); + + auto* q = static_cast(dev_ctx.stream()); + + std::vector assign_values; + assign_values.reserve(values.size()); + for (const auto& val : values) { + assign_values.emplace_back(val.to()); + } + q->memcpy(out_data, &assign_values[0], assign_values.size() * sizeof(T)); + q->wait(); + out->Resize(std::vector(shape.cbegin(), shape.cend())); +} + +template +void AssignKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + phi::DenseTensor* out) { + auto out_data = dev_ctx.template Alloc(out); + auto x_data = x.data(); + std::memcpy(out_data, x_data, sizeof(T) * x.numel()); +} + +template +void AssignRawKernel(const phi::Context& dev_ctx, + const paddle::optional& x, + phi::DenseTensor* out) { + show_kernel("AssignRaw-SYCL, type=" << dnn_support::type2String::name()); + + if (x) { + if (!x->initialized()) { + return; + } + auto x_data = x->data(); + auto out_data = dev_ctx.template Alloc(out); + + auto* q = static_cast(dev_ctx.stream()); + q->memcpy(out_data, x_data, x->numel()); + q->wait(); + } +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(assign_value, + intel_gpu, + ALL_LAYOUT, + custom_kernel::AssignValueKernel, + int, + int64_t, + 
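(Editor's aside, not part of the patch.) ArgsortKernel above copies the tensor to the host, sorts it row by row with FullSort (transposing first when the sort axis is not the last one), and copies values and indices back to the device. A minimal sketch of that host-side step, using a hypothetical helper name `HostRowArgsort` and omitting the NaN handling from the real comparator:

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Sort each of `rows` rows of length `cols`, recording the original column
// index of every element: the same work FullSort does on the host buffers
// before ArgsortKernel copies the results back to the intel_gpu device.
template <typename T>
void HostRowArgsort(const T* in, int rows, int cols, bool descending,
                    T* out, int64_t* indices) {
  std::vector<std::pair<T, int64_t>> row(cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) row[c] = {in[r * cols + c], c};
    std::sort(row.begin(), row.end(),
              [descending](const std::pair<T, int64_t>& a,
                           const std::pair<T, int64_t>& b) {
                return descending ? a.first > b.first : a.first < b.first;
              });
    for (int c = 0; c < cols; ++c) {
      out[r * cols + c] = row[c].first;
      indices[r * cols + c] = row[c].second;
    }
  }
}
```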
float, + double) {} + +PD_BUILD_PHI_KERNEL(assign_raw, + intel_gpu, + ALL_LAYOUT, + custom_kernel::AssignRawKernel, + int, + int64_t, + float, + double) {} diff --git a/backends/intel_gpu/kernels/cast_kernel.cc b/backends/intel_gpu/kernels/cast_kernel.cc new file mode 100644 index 000000000..4ace71737 --- /dev/null +++ b/backends/intel_gpu/kernels/cast_kernel.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void CastKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + phi::DataType out_dtype, + phi::DenseTensor* out) { + show_kernel("Cast-SYCL"); + + auto x_data = x.data(); + out->Resize(x.dims()); + auto numel = x.numel(); + auto* q = static_cast(dev_ctx.stream()); + + switch (out_dtype) { + case phi::DataType::BFLOAT16: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = + static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::FLOAT16: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = + static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::FLOAT32: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::FLOAT64: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::INT8: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::INT16: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::INT32: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::INT64: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::UINT8: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + case phi::DataType::BOOL: { + auto out_data = dev_ctx.template Alloc(out); + q->parallel_for(numel, [=](auto& i) { + out_data[i] = static_cast(static_cast(x_data[i])); + }); + break; + } + default: + break; + } + q->wait(); +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(cast, + intel_gpu, + ALL_LAYOUT, + custom_kernel::CastKernel, + float, + double, + 
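(Editor's aside, not part of the patch.) Every case in CastKernel's switch above runs the same `parallel_for` body and differs only in the destination type. A hypothetical helper, here called `CastOnQueue`, shows that single pattern in isolation:

```cpp
#include <cstdint>
#include <sycl/sycl.hpp>

// The per-dtype loop body that CastKernel's switch repeats, written once
// for a generic (InT -> OutT) pair on a SYCL queue.
template <typename InT, typename OutT>
void CastOnQueue(sycl::queue* q, const InT* in, OutT* out, int64_t n) {
  q->parallel_for(sycl::range<1>(n),
                  [=](sycl::id<1> i) { out[i] = static_cast<OutT>(in[i]); });
  q->wait();
}
```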
int, + int64_t, + int16_t, + bool, + int8_t, + uint8_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/backends/intel_gpu/kernels/compare_kernel.cc b/backends/intel_gpu/kernels/compare_kernel.cc new file mode 100644 index 000000000..d26ad75a6 --- /dev/null +++ b/backends/intel_gpu/kernels/compare_kernel.cc @@ -0,0 +1,332 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void RawCompareKernelSycl(const phi::Context& dev_ctx, + std::string kernel_name, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out, + const F& func, + const FF& float_func) { + show_kernel(kernel_name << "-SYCL type=" + << dnn_support::type2String::name()); + + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto dst_dims = phi::BroadcastDims(axis, x_dims, y_dims); + + auto x_data = x.data(); + auto y_data = y.data(); + auto out_data = dev_ctx.template Alloc(out); + auto numel = out->numel(); + + auto* q = static_cast(dev_ctx.stream()); + // if float_func == func only func is to be calculated + if (float_func != func && std::is_floating_point::value) { + q->parallel_for(numel, + [=](auto& i) { float_func(x_data, y_data, out_data, i); }); + } else { + q->parallel_for(numel, [=](auto& i) { func(x_data, y_data, out_data, i); }); + } + q->wait(); +} + +template +void RawCompareKernelDNN(const phi::Context& dev_ctx, + std::string kernel_name, + dnnl::algorithm binary_type, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + show_kernel(kernel_name << "-DNN type=" + << dnn_support::type2String::name()); + + void* stream = const_cast(dev_ctx.stream()); + auto* q = static_cast(const_cast(dev_ctx.stream())); + + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + + auto eng = dnnl::sycl_interop::make_engine(q->get_device(), q->get_context()); + auto engine_stream = dnnl::sycl_interop::make_stream(eng, *q); + + dnnl::memory::dims dims_x = x.dims(); + dnnl::memory::dims dims_y = y.dims(); + dnnl::memory::dims dims_out = out->dims(); + + phi::update_broadcast(dims_x, dims_y, axis); + + auto md_x = dnnl::memory::desc( + dims_x, dnn_support::toDnnType::type, dnn_support::dims2Tag(dims_x)); + + auto md_y = dnnl::memory::desc( + dims_y, dnn_support::toDnnType::type, dnn_support::dims2Tag(dims_y)); + auto md_out = dnnl::memory::desc(dims_out, + dnn_support::toDnnType::type, + dnn_support::dims2Tag(dims_out)); + + auto x_mem = dnnl::memory(md_x, eng, x.data()); + auto y_mem = dnnl::memory(md_y, eng, y.data()); + + auto out_data = dev_ctx.template Alloc(out); + + auto out_mem = dnnl::memory(md_out, eng, out_data); + + auto oper_desc = dnnl::binary::desc(binary_type, md_x, md_y, md_out); + auto prim_desc = dnnl::binary::primitive_desc(oper_desc, eng); + auto prim = dnnl::binary(prim_desc); + 
+ std::unordered_map binary_args; + binary_args.insert({DNNL_ARG_SRC_0, x_mem}); + binary_args.insert({DNNL_ARG_SRC_1, y_mem}); + binary_args.insert({DNNL_ARG_DST, out_mem}); + + prim.execute(engine_stream, binary_args); + engine_stream.wait(); +} + +template +void EqualityKernel(const phi::Context& dev_ctx, + std::string kernel_name, + dnnl::algorithm binary_type, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out, + const F& func, + const FF& float_func) { + if constexpr (std::is_same::value) { + RawCompareKernelDNN(dev_ctx, kernel_name, binary_type, x, y, axis, out); + } else { + RawCompareKernelSycl( + dev_ctx, kernel_name, x, y, axis, out, float_func, func); + } +} + +template +void CompareKernel(const phi::Context& dev_ctx, + std::string kernel_name, + dnnl::algorithm binary_type, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out, + const F& func) { + if constexpr (std::is_same::value) { + RawCompareKernelDNN(dev_ctx, kernel_name, binary_type, x, y, axis, out); + } else { + RawCompareKernelSycl(dev_ctx, kernel_name, x, y, axis, out, func, func); + } +} + +template +void NotEqualKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + EqualityKernel( + dev_ctx, + "NotEqual", + dnnl::algorithm::binary_ne, + x, + y, + axis, + out, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = x_data[i] != y_data[i]; + }, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = static_cast( + std::fabs(static_cast(x_data[i] - y_data[i])) >= 1e-8); + }); +} + +template +void EqualKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + EqualityKernel( + dev_ctx, + "Equal", + dnnl::algorithm::binary_eq, + x, + y, + axis, + out, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = x_data[i] == y_data[i]; + }, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = static_cast( + std::fabs(static_cast(x_data[i] - y_data[i])) < 1e-8); + }); +} + +template +void LessThanKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + CompareKernel(dev_ctx, + "LessThanKernel", + dnnl::algorithm::binary_lt, + x, + y, + axis, + out, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = x_data[i] < y_data[i]; + }); +} + +template +void LessEqualKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + CompareKernel(dev_ctx, + "LessEqual", + dnnl::algorithm::binary_le, + x, + y, + axis, + out, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = x_data[i] <= y_data[i]; + }); +} + +template +void GreaterThanKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + CompareKernel(dev_ctx, + "GreaterThan", + dnnl::algorithm::binary_gt, + x, + y, + axis, + out, + [](T* x_data, T* y_data, bool* out_data, int64_t i) { + out_data[i] = x_data[i] > y_data[i]; + }); +} + +template +void GreaterEqualKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + CompareKernel(dev_ctx, + "GreaterEqual", + dnnl::algorithm::binary_ge, + x, + y, + axis, + out, + [](T* x_data, T* y_data, bool* 
out_data, int64_t i) { + out_data[i] = x_data[i] >= y_data[i]; + }); +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(not_equal, + intel_gpu, + ALL_LAYOUT, + custom_kernel::NotEqualKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(equal, + intel_gpu, + ALL_LAYOUT, + custom_kernel::EqualKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(less_than, + intel_gpu, + ALL_LAYOUT, + custom_kernel::LessThanKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(less_equal, + intel_gpu, + ALL_LAYOUT, + custom_kernel::LessEqualKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(greater_than, + intel_gpu, + ALL_LAYOUT, + custom_kernel::GreaterThanKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(greater_equal, + intel_gpu, + ALL_LAYOUT, + custom_kernel::GreaterEqualKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} diff --git a/backends/intel_gpu/kernels/dnn_support.hpp b/backends/intel_gpu/kernels/dnn_support.hpp index 43b4228ba..3102a724c 100644 --- a/backends/intel_gpu/kernels/dnn_support.hpp +++ b/backends/intel_gpu/kernels/dnn_support.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,9 +14,6 @@ // clang-format off #pragma once -#include -#include -#include #include #include #include @@ -113,3 +110,188 @@ const T* shortPath(const T* p) { ss << "[" << shortPath(__FILE__) << ":" << __LINE__ << "] :" << x; \ throw std::runtime_error(ss.str()); \ } + +template +std::ostream& operator<<(std::ostream& o, const std::vector& v) { + o << "{ "; + for (auto item : v) { + o << item << ","; + } + o << " }"; + return o; +} + +namespace dnn_support { +template +struct toDnnType {}; + +template <> +struct toDnnType { + static const dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct toDnnType { + static const dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct toDnnType { + static const dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +#ifdef CUSTOM_DNN + +template <> +struct toDnnType { + static const dnnl::memory::data_type type = dnnl::memory::data_type::f64; +}; + +#endif + +template +dnnl::memory::format_tag dims2Tag(const T& d) { + switch (d.size()) { + case 1: + return dnnl::memory::format_tag::a; + case 2: + return dnnl::memory::format_tag::ab; + case 3: + return dnnl::memory::format_tag::abc; + case 4: + return dnnl::memory::format_tag::abcd; + case 5: + return dnnl::memory::format_tag::abcde; + case 6: + return dnnl::memory::format_tag::abcdef; + + default: + show_error("This size is not supported size=" << d.size()); + } + return dnnl::memory::format_tag::a; +} + +template > +dnnl::memory::format_tag axis2Tag(const T& d) { + switch (d.size()) { + case 1: + return dnnl::memory::format_tag::a; + case 2: + if (d == T{1, 0}) { + return dnnl::memory::format_tag::ba; + } + return dnnl::memory::format_tag::ab; + + case 3: + + if (d == T{0, 2, 1}) { + return dnnl::memory::format_tag::acb; + } + + if (d == T{1, 0, 2}) { + return dnnl::memory::format_tag::bac; + } + + if (d == 
T{2, 1, 0}) { + return dnnl::memory::format_tag::cba; + } + + if (d == T{0, 1, 2}) { + return dnnl::memory::format_tag::abc; + } + + rise_error("Can't convert tag for " << d); + + case 4: + + if (d == T{0, 1, 3, 2}) { + return dnnl::memory::format_tag::abdc; + } + + if (d == T{0, 3, 1, 2}) { + return dnnl::memory::format_tag::adbc; + } + + if (d == T{0, 2, 1, 3}) { + return dnnl::memory::format_tag::acbd; + } + + if (d == T{0, 1, 2, 3}) { + return dnnl::memory::format_tag::abcd; + } + rise_error("Can't convert tag for " << d); + + case 5: + + if (d == T{0, 1, 2, 3, 4}) { + return dnnl::memory::format_tag::abcde; + } + + rise_error("Can't convert tag for " << d); + + default: + show_error("This size is not supported size=" << d.size()); + rise_error("Lack of support " << d); + } + return dnnl::memory::format_tag::a; +} + +template +struct type2String; + +template <> +struct type2String { + constexpr static const char* name() { return "double"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "float"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "int32_t"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "int64_t"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "bool"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "uchar"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "short"; } +}; + +template <> +struct type2String { + constexpr static const char* name() { return "signed char"; } +}; + +template +dnnl::memory::dims computeStrides(const std::vector& dims, + const std::vector& axis) { + size_t rank = axis.size(); + std::vector strides(rank); + unsigned int total_stride = 1; + for (int i = rank - 1; i >= 0; --i) { + strides[axis[i]] = total_stride; + total_stride *= dims[axis[i]]; + } + show_debug("computeStrides strides=" << strides << " from [ dims=" << dims + << " axis=" << axis << "]"); + return strides; +} + +} // namespace dnn_support diff --git a/backends/intel_gpu/kernels/elementwise_kernel.cc b/backends/intel_gpu/kernels/elementwise_kernel.cc new file mode 100644 index 000000000..abac8f93a --- /dev/null +++ b/backends/intel_gpu/kernels/elementwise_kernel.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
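(Editor's aside, not part of the patch.) RawCompareKernelDNN above, and the oneDNN multiply path in the elementwise kernel below, follow the same sequence: memory descriptors for the two sources and the destination, a `dnnl::binary::desc` with the chosen algorithm, a primitive, and an `execute` call keyed by `DNNL_ARG_SRC_0`/`DNNL_ARG_SRC_1`/`DNNL_ARG_DST`. A self-contained sketch of that flow on a plain CPU engine with made-up shapes; only the SYCL-interop engine/stream creation differs in the patch:

```cpp
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

// greater_equal with broadcasting: a is 2x3, b is 1x3, out holds 0/1 flags.
int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> a(2 * 3, 1.f), b(3, 2.f), out(2 * 3, 0.f);

  auto md_a = dnnl::memory::desc({2, 3}, dnnl::memory::data_type::f32,
                                 dnnl::memory::format_tag::ab);
  auto md_b = dnnl::memory::desc({1, 3}, dnnl::memory::data_type::f32,
                                 dnnl::memory::format_tag::ab);
  auto md_o = md_a;

  dnnl::memory mem_a(md_a, eng, a.data());
  dnnl::memory mem_b(md_b, eng, b.data());
  dnnl::memory mem_o(md_o, eng, out.data());

  auto desc = dnnl::binary::desc(dnnl::algorithm::binary_ge, md_a, md_b, md_o);
  auto pd = dnnl::binary::primitive_desc(desc, eng);
  dnnl::binary(pd).execute(strm, {{DNNL_ARG_SRC_0, mem_a},
                                  {DNNL_ARG_SRC_1, mem_b},
                                  {DNNL_ARG_DST, mem_o}});
  strm.wait();  // out now contains 1.f where a >= b, else 0.f
  return 0;
}
```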
+ +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void MultiplyRawKernelGPU(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + show_kernel( + "ElementWise-SYCL-MUL type=" << dnn_support::type2String::name()); + void* stream = const_cast(dev_ctx.stream()); + auto* q = static_cast(stream); + + T* out_data = dev_ctx.Alloc(out); + + auto NOUT = out->numel(); + + auto input_x = x.data(); + auto input_y = y.data(); + + q->submit([&](sycl::handler& h) { + h.parallel_for(NOUT, [input_x, input_y, out_data](sycl::id<1> i) { + out_data[i] = input_x[i] * input_y[i]; + }); + }); + + q->wait(); +} + +template +void MultiplyKernelGPU(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + phi::DenseTensor* out) { + int axis = -1; + MultiplyRawKernelGPU(dev_ctx, x, y, axis, out); +} + +template +void MultiplyOneDNNRawKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + show_kernel( + "ElementWise-ONEDNN type=" << dnn_support::type2String::name()); + auto* q = static_cast(const_cast(dev_ctx.stream())); + + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + + auto eng = dnnl::sycl_interop::make_engine(q->get_device(), q->get_context()); + auto engine_stream = dnnl::sycl_interop::make_stream(eng, *q); + + dnnl::memory::dims dims_x = x.dims(); + dnnl::memory::dims dims_y = y.dims(); + dnnl::memory::dims dims_out = out->dims(); + + phi::update_broadcast(dims_x, dims_y, axis); + + auto md_x = dnnl::memory::desc( + dims_x, dnn_support::toDnnType::type, dnn_support::dims2Tag(dims_x)); + + auto md_y = dnnl::memory::desc( + dims_y, dnn_support::toDnnType::type, dnn_support::dims2Tag(dims_y)); + auto md_out = dnnl::memory::desc(dims_out, + dnn_support::toDnnType::type, + dnn_support::dims2Tag(dims_out)); + + auto x_mem = dnnl::memory(md_x, eng, x.data()); + auto y_mem = dnnl::memory(md_y, eng, y.data()); + + auto out_data = dev_ctx.template Alloc(out); + + auto out_mem = dnnl::memory(md_out, eng, out_data); + + auto oper_desc = + dnnl::binary::desc(dnnl::algorithm::binary_mul, md_x, md_y, md_out); + auto prim_desc = dnnl::binary::primitive_desc(oper_desc, eng); + auto prim = dnnl::binary(prim_desc); + + std::unordered_map binary_args; + binary_args.insert({DNNL_ARG_SRC_0, x_mem}); + binary_args.insert({DNNL_ARG_SRC_1, y_mem}); + binary_args.insert({DNNL_ARG_DST, out_mem}); + + prim.execute(engine_stream, binary_args); + engine_stream.wait(); +} + +template +void MultiplyOneDNNKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + phi::DenseTensor* out) { + int axis = -1; + MultiplyOneDNNRawKernel(dev_ctx, x, y, axis, out); +} + +template +void MultiplyMainRaw(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + int axis, + phi::DenseTensor* out) { + if constexpr (std::is_same::value || std::is_same::value + //|| std::is_same::value + ) { + MultiplyOneDNNRawKernel(dev_ctx, x, y, axis, out); + } else { + MultiplyRawKernelGPU(dev_ctx, x, y, axis, out); + } +} +template +void MultiplyMain(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + phi::DenseTensor* out) { + int axis = -1; + MultiplyMainRaw(dev_ctx, x, y, axis, out); +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(multiply_raw, + 
intel_gpu, + ALL_LAYOUT, + custom_kernel::MultiplyMainRaw, + int32_t, + int64_t, + float, + double) {} + +PD_BUILD_PHI_KERNEL(multiply, + intel_gpu, + ALL_LAYOUT, + custom_kernel::MultiplyMain, + int32_t, + int64_t, + float, + double) {} diff --git a/backends/intel_gpu/kernels/full_kernel.cc b/backends/intel_gpu/kernels/full_kernel.cc new file mode 100644 index 000000000..f4f4c3d8a --- /dev/null +++ b/backends/intel_gpu/kernels/full_kernel.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void FullValue(const phi::Context& dev_ctx, + phi::DenseTensor* tensor, + VType val) { + show_kernel("FullValue type=" << dnn_support::type2String::name()); + auto t = dev_ctx.template Alloc(tensor); + auto* q = static_cast(dev_ctx.stream()); + auto num = tensor->numel(); + show_debug("FullValue size=" << num << " sizeof(T)=" << sizeof(T)); + auto e = q->submit([&](sycl::handler& h) { h.fill(t, val, num); }); + q->wait(); +} + +template +void FullKernel(const phi::Context& dev_ctx, + const phi::IntArray& shape, + const phi::Scalar& val, + phi::DataType dtype, + phi::DenseTensor* out) { + auto int_shape = shape.GetData(); + out->Resize(std::vector(int_shape.cbegin(), int_shape.cend())); + FullValue(dev_ctx, out, val.to()); +} +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(full, + intel_gpu, + ALL_LAYOUT, + custom_kernel::FullKernel, + float, + double, + uint8_t, + int16_t, + int32_t, + int64_t, + bool) {} diff --git a/backends/intel_gpu/kernels/kernels.h b/backends/intel_gpu/kernels/kernels.h new file mode 100644 index 000000000..10f964aa0 --- /dev/null +++ b/backends/intel_gpu/kernels/kernels.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
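(Editor's aside, not part of the patch.) The SYCL-only kernels above rely on two queue patterns: `handler::fill` in FullValue/FullKernel and the `parallel_for` shortcut in the elementwise multiply. A compact standalone sketch of both on USM device memory, with made-up sizes; in the plugin the queue comes from `dev_ctx.stream()` rather than being constructed directly:

```cpp
#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;  // default device; the plugin gets its queue from dev_ctx
  constexpr size_t n = 1024;

  float* a = sycl::malloc_device<float>(n, q);
  float* b = sycl::malloc_device<float>(n, q);
  float* out = sycl::malloc_device<float>(n, q);

  // FullKernel pattern: fill a device buffer with one value via a handler.
  q.submit([&](sycl::handler& h) { h.fill(a, 2.0f, n); });
  q.submit([&](sycl::handler& h) { h.fill(b, 3.0f, n); });
  q.wait();

  // MultiplyRawKernelGPU pattern: one work-item per output element.
  q.parallel_for(sycl::range<1>(n),
                 [=](sycl::id<1> i) { out[i] = a[i] * b[i]; });
  q.wait();

  sycl::free(a, q);
  sycl::free(b, q);
  sycl::free(out, q);
  return 0;
}
```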
+ +#pragma once + +#include "paddle/phi/capi/all.h" +namespace custom_kernel { +template +void TransposeKernelGPU(const phi::Context& ctx, + const phi::DenseTensor& x, + const std::vector& axis, + phi::DenseTensor* out); + +template +void SoftmaxKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + int axis, + phi::DenseTensor* out); +} // namespace custom_kernel diff --git a/backends/intel_gpu/kernels/memcpy_kernel.cc b/backends/intel_gpu/kernels/memcpy_kernel.cc new file mode 100644 index 000000000..340a03392 --- /dev/null +++ b/backends/intel_gpu/kernels/memcpy_kernel.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void MemcpyD2HKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + int dst_place_type, + phi::DenseTensor* out) { + show_kernel("memcpy_d2h"); + auto out_data = dev_ctx.HostAlloc(out); + auto x_data = x.data(); + void* stream = const_cast(dev_ctx.stream()); + auto* q = static_cast(stream); + show_debug("memcpy_d2h -> memcpy(to=" << std::hex << out_data << ", from=" + << x_data << ", size=" << std::dec + << x.memory_size() << ")"); + q->memcpy(out_data, x_data, x.memory_size()); +} + +template +void MemcpyH2DKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + int dst_place_type, + phi::DenseTensor* out) { + show_kernel("memcpy_h2d"); + auto out_data = dev_ctx.Alloc(out); + auto x_data = x.data(); + + void* stream = const_cast(dev_ctx.stream()); + auto* q = static_cast(stream); + show_debug("memcpy_h2d -> memcpy(to=" << std::hex << out_data << ", from=" + << x_data << ", size=" << std::dec + << x.memory_size() << ")"); + q->memcpy(out_data, x_data, x.memory_size()); +} + +template +void MemcpyKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + int dst_place_type, + phi::DenseTensor* out) { + if (!x.initialized()) { + return; + } + // The dst_place_type is defined in paddle/fluid/operators/memcpy.h: + // CPU = 0, CUDA = 1, CUDA_PINNED = 2, + // XPU = 3, NPU = 4, NPU_PINNED = 5, + // CUSTOM_DEVICE = 6 + if (dst_place_type == 0) { // CPU + MemcpyD2HKernel(dev_ctx, x, 0, out); + } else if (dst_place_type == 6) { // custom_device + MemcpyH2DKernel(dev_ctx, x, 6, out); + } +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(memcpy_d2h, + intel_gpu, + ALL_LAYOUT, + custom_kernel::MemcpyD2HKernel, + float, + double, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(memcpy_h2d, + intel_gpu, + ALL_LAYOUT, + custom_kernel::MemcpyH2DKernel, + float, + double, + int32_t, + int64_t, + bool) {} + +PD_BUILD_PHI_KERNEL(memcpy, + intel_gpu, + ALL_LAYOUT, + custom_kernel::MemcpyKernel, + phi::dtype::float16, + float, + double, + int, + int64_t, + bool) {} diff --git a/backends/intel_gpu/kernels/phi_funcs.h b/backends/intel_gpu/kernels/phi_funcs.h new file mode 100644 index 
000000000..c585ba20f --- /dev/null +++ b/backends/intel_gpu/kernels/phi_funcs.h @@ -0,0 +1,391 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/capi/all.h" + +namespace phi { + +template +T TolerableValue(const T& x) { + const T kApproInf = 1e20; + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; +} + +template +inline std::string to_string(const T& val) { + std::stringstream ss; + ss << val; + return ss.str(); +} + +template <> +inline std::string to_string(const phi::DataType& val) { + if (val == phi::DataType::FLOAT32) { + return "float32"; + } else if (val == phi::DataType::FLOAT64) { + return "float64"; + } else if (val == phi::DataType::INT32) { + return "int32"; + } else if (val == phi::DataType::INT64) { + return "int64"; + } else { + return "undefined"; + } +} + +template <> +inline std::string to_string(const phi::DataLayout& val) { + if (val == phi::DataLayout::NCHW) { + return "nchw"; + } else { + return "undefined"; + } +} + +template +inline std::string to_string(const std::vector& vec) { + std::stringstream ss; + for (auto i = 0; i < vec.size(); ++i) { + ss << to_string(vec[i]); + if (i < vec.size() - 1) { + ss << ", "; + } + } + return ss.str(); +} + +inline std::vector slice_ddim(const std::vector& dim, + int begin, + int end) { + return std::vector(dim.cbegin() + begin, dim.cbegin() + end); +} + +template +inline int64_t product(const std::vector& ddim) { + return std::accumulate(ddim.cbegin(), ddim.cend(), 1, std::multiplies()); +} + +template +T vec_product(const std::vector& a, const std::vector& b) { + T ret = 0; + for (auto i = 0; i < a.size(); ++i) { + ret += a[i] * b[i]; + } + return ret; +} + +template +T vec_product(const T* a, const T* b, size_t a_size) { + T ret = 0; + for (auto i = 0; i < a_size; ++i) { + ret += a[i] * b[i]; + } + return ret; +} +namespace funcs { + +inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +inline int SizeToAxis(const int axis, std::vector dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +inline int SizeFromAxis(const int axis, std::vector dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +inline int SizeOutAxis(const int axis, std::vector dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +template +inline void CheckAndUpdateSliceAttrs(const std::vector& in_dims, + const std::vector& axes, + std::vector* starts, + std::vector* ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + for (size_t i = 0; i < axes.size(); ++i) { + T axis = axes[i]; + PD_CHECK(axis < in_dims.size(), + "The axis value should be less than the rank of input, " + "but received axes[%d] = %d, 
rank of input is %d.", + i, + axis, + in_dims.size()); + + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + continue; + } + + T dim_value = in_dims[axis]; + + if (dim_value > 0) { + T step = steps == nullptr ? 1 : (*steps)[i]; + PD_CHECK( + step != 0, "Step should not be 0, but received step = %d.", step); + + T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; + start = std::max(start, static_cast(0)); + + T end = + 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; + end = std::min(end, dim_value); + + if (step > 0) { + start = std::min(start, dim_value); + end = std::max(end, static_cast(0)); + PD_CHECK(end >= start, + "When step > 0, end should be greater than start, but " + "received end = %d, start = %d.", + end, + start); + } else { + // NOTE(liym27): When step < 0, start should less and equal to + // dim_value-1 + // "end is -1" means contain the 0-th element of this axis. + start = std::min(start, dim_value - 1); + end = std::max(end, static_cast(-1)); + PD_CHECK(start >= end, + "When step < 0, start should be greater than end, but " + "received start = %d, end = %d.", + start, + end); + } + + (*starts)[i] = start; + (*ends)[i] = end; + } else if (dim_value == 0) { + (*starts)[i] = 0; + (*ends)[i] = 0; + } + } +} + +template +inline std::vector GetSliceDims( + const std::vector& in_dims, + const std::vector& axes, + const std::vector& starts, + const std::vector& ends, + std::vector* steps = nullptr, + std::vector* infer_flags = nullptr) { + std::vector slice_dims = in_dims; + + for (size_t i = 0; i < axes.size(); ++i) { + T axis = axes[i]; + if (infer_flags != nullptr && (*infer_flags)[i] == -1) { + slice_dims[axis] = -1; + continue; + } + + T start = starts[i]; + T end = ends[i]; + T step = steps == nullptr ? 1 : (*steps)[i]; + + if (step > 0) { + slice_dims[axis] = (end - start + step - 1) / step; + } else { + slice_dims[axis] = (end - start + step + 1) / step; + } + } + return slice_dims; +} + +template +inline std::vector GetDecreasedDims( + const std::vector& slice_dims, + const std::vector& decrease_axes, + std::vector* infer_flags = nullptr) { + std::vector decreased_dims = slice_dims; + std::vector decrease_flag(slice_dims.size(), 0); + if (decrease_axes.size() > 0) { + for (size_t i = 0; i < decrease_axes.size(); ++i) { + T axis = decrease_axes[i]; + decrease_flag[axis] = 1; + if (infer_flags && (*infer_flags)[i] != -1) { + PD_CHECK(decreased_dims[axis] == 1, + "Decrease dim should be 1, but now received %d", + decreased_dims[axis]); + } + } + + std::vector new_shape; + for (int i = 0; i < decreased_dims.size(); ++i) { + if (decrease_flag[i] == 0) { + new_shape.push_back(decreased_dims[i]); + } + } + + // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and + // uses [1] instead. 
+ if (new_shape.size() == 0) { + new_shape.push_back(1); + } + decreased_dims = std::vector(new_shape.cbegin(), new_shape.cend()); + } + return decreased_dims; +} + +} // namespace funcs + +template +inline void BroadcastTo(const phi::Context& dev_ctx, + const phi::DenseTensor& in, + std::vector out_dims, + int axis, + phi::DenseTensor* out) { + auto in_dims = in.dims(); + + if (in_dims.size() == out_dims.size()) { + bool broadcast = false; + for (auto i = 0; i < in_dims.size(); ++i) { + if (in_dims[i] != out_dims[i]) { + broadcast = true; + break; + } + } + if (!broadcast) { + out->ShareDataWith(in); + return; + } + } + + out->Resize(out_dims); + auto out_data = dev_ctx.template Alloc(out); + auto in_data = in.data(); + + axis = axis == -1 ? std::abs(static_cast(in_dims.size()) - + static_cast(out_dims.size())) + : axis; + auto retain = static_cast(out_dims.size()) - axis - + static_cast(in_dims.size()); + std::vector tmp_dims; + for (auto i = 0; i < axis; ++i) { + tmp_dims.push_back(1); + } + tmp_dims.insert(tmp_dims.end(), in_dims.cbegin(), in_dims.cend()); + for (auto i = 0; i < retain; ++i) { + tmp_dims.push_back(1); + } + + auto numel = out->numel(); + std::vector index(out_dims.size(), 0); + std::vector in_step(tmp_dims.size(), 1); + std::vector out_step(out_dims.size(), 1); + for (auto i = tmp_dims.size() - 1; i > 0; --i) { + in_step[i - 1] = in_step[i] * tmp_dims[i]; + } + for (auto i = out_dims.size() - 1; i > 0; --i) { + out_step[i - 1] = out_step[i] * out_dims[i]; + } + + for (auto i = 0; i < numel; ++i) { + auto src_index = index; + for (auto j = 0; j < tmp_dims.size(); ++j) { + if (tmp_dims[j] == 1) { + src_index[j] = 0; + } + } + + out_data[phi::vec_product(index, out_step)] = + in_data[phi::vec_product(src_index, in_step)]; + + index.back()++; + for (auto j = index.size() - 1; j > 0; --j) { + if (index[j] >= out_dims[j]) { + index[j] = 0; + index[j - 1]++; + } else { + break; + } + } + } +} + +inline std::vector BroadcastDims(int axis, + const std::vector& x_dims, + const std::vector& y_dims) { + axis = (axis == -1 ? std::abs(static_cast(x_dims.size()) - + static_cast(y_dims.size())) + : axis); + std::vector dst_dims; + if (x_dims.size() == y_dims.size()) { + for (auto i = 0; i < x_dims.size(); ++i) { + dst_dims.push_back(std::max(x_dims[i], y_dims[i])); + } + } else if (x_dims.size() >= y_dims.size()) { + dst_dims = x_dims; + for (auto i = 0; i < y_dims.size(); ++i) { + dst_dims[axis + i] = std::max(dst_dims[axis + i], y_dims[i]); + } + } else { + dst_dims = y_dims; + for (auto i = 0; i < x_dims.size(); ++i) { + dst_dims[axis + i] = std::max(dst_dims[axis + i], x_dims[i]); + } + } + + return dst_dims; +} + +void inline align_broadcast(std::vector& from, // NOLINT + std::vector& to, // NOLINT + int axis) { + std::vector tmp(from.size(), 1); + std::copy(to.begin(), + to.end(), + tmp.begin() + ((axis == -1) ? (from.size() - to.size()) : axis)); + to = std::move(tmp); +} + +inline void update_broadcast(std::vector& x, // NOLINT + std::vector& y, // NOLINT + int axis) { + if (x.size() == y.size()) return; + + if (x.size() > y.size()) { + align_broadcast(x, y, axis); + } else { + align_broadcast(y, x, axis); + } +} + +} // namespace phi diff --git a/backends/intel_gpu/kernels/reduce_kernel.cc b/backends/intel_gpu/kernels/reduce_kernel.cc new file mode 100644 index 000000000..31ce0174c --- /dev/null +++ b/backends/intel_gpu/kernels/reduce_kernel.cc @@ -0,0 +1,231 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
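(Editor's aside, not part of the patch.) A worked example of the broadcast helpers above for `x` dims `{2, 3, 4}`, `y` dims `{3, 4}`, `axis = -1`: BroadcastDims keeps x's rank and takes the per-dimension maximum, giving `{2, 3, 4}`, while update_broadcast pads the lower-rank operand with leading 1s so both oneDNN memory descriptors end up with equal rank:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> x = {2, 3, 4};
  const std::vector<int64_t> y = {3, 4};

  // align_broadcast: copy y into a rank-3 vector of 1s, aligned to the right
  // because axis == -1 (offset = x.size() - y.size() = 1).
  std::vector<int64_t> padded(x.size(), 1);
  std::copy(y.begin(), y.end(), padded.begin() + (x.size() - y.size()));
  assert((padded == std::vector<int64_t>{1, 3, 4}));

  // BroadcastDims: per-dimension maximum of x and the padded y -> {2, 3, 4}.
  std::vector<int64_t> dst(x.size());
  for (size_t i = 0; i < x.size(); ++i) dst[i] = std::max(x[i], padded[i]);
  assert((dst == std::vector<int64_t>{2, 3, 4}));
  return 0;
}
```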
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void ReduceKernel(const phi::Context& dev_ctx, + std::string kernel_name, + const phi::DenseTensor& x, + const std::vector& dims, + dnnl::algorithm reduction_type, + bool keep_dim, + bool reduce_all, + phi::DenseTensor* out) { + auto x_dims = x.dims(); + auto reduce_dims = dims; + + if (reduce_dims.size() == 0) { + reduce_all = true; + } + + if (reduce_all) { + reduce_dims = std::vector(x_dims.size(), 1); + } else { + auto output_dims(x_dims); + for (size_t i = 0; i < reduce_dims.size(); ++i) { + // handle negative dims, f.e. "-1" means rightmost dimension + int index = (reduce_dims[i] >= 0) ? reduce_dims[i] + : x_dims.size() + reduce_dims[i]; + output_dims[index] = 1; + } + reduce_dims = output_dims; + } + show_kernel( + kernel_name << "-Sycl type=" << dnn_support::type2String::name() + << ", reduce_all=" << reduce_all << ", x_dims=" << x_dims + << ", dims=" << dims << ", keep_dim=" << keep_dim + << ", reduce_dims=" << reduce_dims); + void* stream = const_cast(dev_ctx.stream()); + auto* q = static_cast(const_cast(dev_ctx.stream())); + auto out_data = dev_ctx.template Alloc(out); + + if (x_dims == reduce_dims) { + auto x_data = x.data(); + show_debug(kernel_name << " -> memcpy(to=" << std::hex << out_data + << ", from=" << x_data << ", size=" << std::dec + << x.memory_size() << ")"); + q->memcpy(out_data, x_data, x.memory_size()); + } else { + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + + auto eng = + dnnl::sycl_interop::make_engine(q->get_device(), q->get_context()); + auto engine_stream = dnnl::sycl_interop::make_stream(eng, *q); + + dnnl::memory::dims dims_x = x.dims(); + dnnl::memory::dims dims_out = reduce_dims; + + auto md_x = dnnl::memory::desc( + dims_x, dnn_support::toDnnType::type, dnn_support::dims2Tag(dims_x)); + + auto md_out = dnnl::memory::desc(reduce_dims, + dnn_support::toDnnType::type, + dnn_support::dims2Tag(reduce_dims)); + + auto x_mem = dnnl::memory(md_x, eng, x.data()); + + auto out_mem = dnnl::memory(md_out, eng, out_data); + + auto oper_desc = + dnnl::reduction::desc(reduction_type, md_x, md_out, 0.f, 0.f); + auto prim_desc = dnnl::reduction::primitive_desc(oper_desc, eng); + + auto reduction_prim = dnnl::reduction(prim_desc); + + std::unordered_map reduction_args; + reduction_args.insert({DNNL_ARG_SRC, x_mem}); + reduction_args.insert({DNNL_ARG_DST, out_mem}); + + reduction_prim.execute(engine_stream, reduction_args); + engine_stream.wait(); + } +} + +template +void MeanRawKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + phi::DenseTensor* out) { + ReduceKernel(dev_ctx, + "MeanRaw", + x, + dims, + dnnl::algorithm::reduction_mean, + keep_dim, + reduce_all, + out); +} + +template +void MeanKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + 
const std::vector& dims, + bool keep_dim, + phi::DenseTensor* out) { + bool reduce_all = false; + MeanRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +template +void SumRawKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + phi::DataType out_dtype, + phi::DenseTensor* out) { + ReduceKernel(dev_ctx, + "SumRaw", + x, + dims, + dnnl::algorithm::reduction_sum, + keep_dim, + reduce_all, + out); +} + +template +void SumKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + phi::DataType out_dtype, + bool keep_dim, + phi::DenseTensor* out) { + bool reduce_all = false; + SumRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out); +} + +template +void MaxRawKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + phi::DenseTensor* out) { + ReduceKernel(dev_ctx, + "MaxRaw", + x, + dims, + dnnl::algorithm::reduction_max, + keep_dim, + reduce_all, + out); +} + +template +void MaxKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + bool keep_dim, + phi::DenseTensor* out) { + bool reduce_all = false; + MaxRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +template +void MinRawKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + phi::DenseTensor* out) { + ReduceKernel(dev_ctx, + "MinRaw", + x, + dims, + dnnl::algorithm::reduction_min, + keep_dim, + reduce_all, + out); +} + +template +void MinKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const std::vector& dims, + bool keep_dim, + phi::DenseTensor* out) { + bool reduce_all = false; + MinRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL( + mean_raw, intel_gpu, ALL_LAYOUT, custom_kernel::MeanRawKernel, float) {} +PD_BUILD_PHI_KERNEL( + mean, intel_gpu, ALL_LAYOUT, custom_kernel::MeanKernel, float) {} + +PD_BUILD_PHI_KERNEL( + sum_raw, intel_gpu, ALL_LAYOUT, custom_kernel::SumRawKernel, float) {} +PD_BUILD_PHI_KERNEL( + sum, intel_gpu, ALL_LAYOUT, custom_kernel::SumKernel, float) {} + +PD_BUILD_PHI_KERNEL( + min_raw, intel_gpu, ALL_LAYOUT, custom_kernel::MinRawKernel, float) {} +PD_BUILD_PHI_KERNEL( + min, intel_gpu, ALL_LAYOUT, custom_kernel::MinKernel, float) {} + +PD_BUILD_PHI_KERNEL( + max_raw, intel_gpu, ALL_LAYOUT, custom_kernel::MaxRawKernel, float) {} +PD_BUILD_PHI_KERNEL( + max, intel_gpu, ALL_LAYOUT, custom_kernel::MaxKernel, float) {} diff --git a/backends/intel_gpu/kernels/reshape_kernel.cc b/backends/intel_gpu/kernels/reshape_kernel.cc new file mode 100644 index 000000000..12ad3d757 --- /dev/null +++ b/backends/intel_gpu/kernels/reshape_kernel.cc @@ -0,0 +1,189 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
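The reduce kernels registered above all go through the shared ReduceKernel helper, which maps Paddle's mean/sum/max/min ops onto a single oneDNN reduction primitive (reduce_all is forced when dims is empty, and each reduced axis collapses to 1). A minimal usage sketch from the Python side, assuming the plugin is built and exposes the "intel_gpu" custom device as in the tests below; only float32 is registered:

    import numpy as np
    import paddle

    paddle.set_device("intel_gpu")

    x = paddle.to_tensor(np.random.rand(4, 8).astype("float32"))
    m = paddle.mean(x, axis=1)   # dispatches to MeanKernel -> reduction_mean
    s = paddle.sum(x, axis=0)    # dispatches to SumKernel  -> reduction_sum
    g = paddle.max(x)            # empty dims -> reduce_all branch
    print(m.shape, s.shape, g.shape)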
+#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +static std::vector ValidateShape(const std::vector shape, + const std::vector& in_dims) { + const int64_t in_size = phi::product(in_dims); + std::vector in_dims_vec = in_dims; + bool all_positive = std::all_of(in_dims_vec.cbegin(), + in_dims_vec.cend(), + [](int64_t i) { return i > 0; }); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int64_t unk_dim_val = -1; + const int64_t copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + int64_t capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + PD_CHECK(unk_dim_idx == -1, + "Only one dimension value of 'shape' in ReshapeOp can " + "be -1. But received shape = [%s], shape[%d] is also -1.", + phi::to_string(shape), + i); + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + PD_CHECK(static_cast(i) < in_dims.size(), + "The index of 0 in `shape` must be less than " + "the input tensor X's dimensions. " + "But received shape = [%s], shape[%d] = 0, X's shape = [%s], " + "X's dimensions = %d.", + phi::to_string(shape), + i, + phi::to_string(in_dims), + in_dims.size()); + } else { + PD_CHECK(shape[i] > 0, + "Each dimension value of 'shape' in ReshapeOp must not " + "be negative except one unknown dimension. " + "But received shape = [%s], shape[%d] = %d.", + phi::to_string(shape), + i, + shape[i]); + } + + // NOTE all non-zero values will be converted to True (include negative + // value) + capacity *= (shape[i] ? shape[i] : in_dims[i]); + output_shape[i] = (shape[i] ? static_cast(shape[i]) : in_dims[i]); + } + + if (unk_dim_idx != -1) { + if (all_positive) { + // in_size < 0 and is un-determinate in compile time, skip the check, + // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, in_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -in_size / capacity; + PD_CHECK(output_shape[unk_dim_idx] * capacity == -in_size, + "The 'shape' attribute in ReshapeOp is invalid. " + "The input tensor X'size must be divisible by known " + "capacity of 'shape'. " + "But received X's shape = [%s], X's size = %d, " + "'shape' is [%s], known capacity of 'shape' is %d.", + phi::to_string(in_dims), + in_size, + phi::to_string(shape), + capacity); + } else { + output_shape[unk_dim_idx] = -1; + } + } else { + if (all_positive) { + PD_CHECK(capacity == in_size, + "The 'shape' in ReshapeOp is invalid. " + "The input tensor X'size must be equal to the capacity of " + "'shape'. " + "But received X's shape = [%s], X's size = %d, 'shape' is " + "[%s], the capacity of 'shape' is %d.", + phi::to_string(in_dims), + in_size, + phi::to_string(shape), + capacity); + } + } + + // support reshape with zero-input(input tensor with product(shape) == 0) + // by now we require that if the input tensor is zero shape, the target + // shape of output must be zero + if (in_size == 0) { + PD_CHECK(capacity < in_size, + "The 'shape' in ReshapeOp is invalid. " + "The input tensor X's shape = [%s], X's capacity = %d." 
+ "But the target shape of Out is [%s], the " + "capacity of 'Out' is %d.", + phi::to_string(in_dims), + in_size, + phi::to_string(shape), + capacity); + } + + return output_shape; +} + +template +void ReshapeKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::IntArray& shape, + phi::DenseTensor* out) { + show_kernel("Reshape type=" << dnn_support::type2String::name()); + auto x_dims = x.dims(); + auto out_dims = ValidateShape(shape.GetData(), x_dims); + out->Resize(out_dims); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + if (x.dims()[0] == out->dims()[0]) { + out->share_lod(x); + } + + if (!(x.initialized() && x.Holder() == out->Holder())) { + show_debug( + "Reshape type initialized=" << dnn_support::type2String::name()); + dev_ctx.Alloc(out, x.dtype()); + auto dims = out->dims(); + auto x_data = x.data(); + auto out_data = out->data(); + + void* stream = const_cast(dev_ctx.stream()); + auto* q = static_cast(stream); + q->memcpy(out_data, x_data, x.numel() * sizeof(T)); + + out->Resize(dims); + out->ResetLoD(x.lod()); + } +} + +template +void ReshapeWithXShape(const phi::Context& dev_ctx, + const phi::DenseTensor& x, + const phi::IntArray& shape, + phi::DenseTensor* out, + phi::DenseTensor* xshape) { + ReshapeKernel(dev_ctx, x, shape, out); +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(reshape, + intel_gpu, + ALL_LAYOUT, + custom_kernel::ReshapeKernel, + float, + double, + int8_t, + int16_t, + int32_t, + int64_t, + uint8_t, + bool) {} + +PD_BUILD_PHI_KERNEL(reshape_with_xshape, + intel_gpu, + ALL_LAYOUT, + custom_kernel::ReshapeWithXShape, + float, + double, + int8_t, + int16_t, + int32_t, + int64_t, + uint8_t, + bool) {} diff --git a/backends/intel_gpu/kernels/slice_kernel.cc b/backends/intel_gpu/kernels/slice_kernel.cc new file mode 100644 index 000000000..8a6b9d6f4 --- /dev/null +++ b/backends/intel_gpu/kernels/slice_kernel.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void SliceRawKernel(const phi::Context& ctx, + const phi::DenseTensor& input, + const std::vector& axes, + const phi::IntArray& starts_arr, + const phi::IntArray& ends_arr, + const std::vector& infer_flags, + const std::vector& decrease_axis, + phi::DenseTensor* out) { + show_kernel("SliceRawKernel, type=" << dnn_support::type2String::name()); + + // Step 1: Get the accurate attribute value of starts and ends + auto starts = starts_arr.GetData(); + auto ends = ends_arr.GetData(); + PD_CHECK(starts.size() == axes.size(), + "The size of starts must be equal to the size of axes."); + PD_CHECK(ends.size() == axes.size(), + "The size of ends must be equal to the size of axes."); + + void* stream = const_cast(ctx.stream()); + auto* q = static_cast(stream); + + // Step 2: Compute output + auto in = &input; + auto in_data = input.data(); + int rank = input.dims().size(); + + auto in_dims = in->dims(); + auto out_dims = out->dims(); + auto slice_dims = out_dims; + + // 2.1 Infer output dims + for (size_t i = 0; i < axes.size(); ++i) { + // when start == -1 && end == start+1 + if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) { + auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); + if (ret != decrease_axis.end()) { + ends[i] = in_dims[axes[i]]; + } + } + } + + phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends); + slice_dims = phi::funcs::GetSliceDims( + in_dims, axes, starts, ends, nullptr, nullptr); + out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis); + + // 2.2 Get output + auto offsets = std::vector(rank); + auto extents = std::vector(rank); + + for (size_t i = 0; i < rank; ++i) { + offsets[i] = 0; + extents[i] = slice_dims[i]; + } + for (size_t i = 0; i < axes.size(); ++i) { + offsets[axes[i]] = starts[i]; + } + + out->Resize(slice_dims); + auto out_data = ctx.template Alloc(out); + + std::vector in_step(rank, 1); + + for (auto i = rank - 1; i > 0; --i) { + in_step[i - 1] = in_step[i] * in_dims[i]; + } + + auto numel = phi::product(slice_dims); + auto index = std::vector(offsets.cbegin(), offsets.cend()); + + std::vector temp(numel / slice_dims.back() * 2, 0); + { + sycl::buffer temp_buf(temp); + sycl::buffer index_buf(index); + sycl::buffer offsets_buf(offsets); + sycl::buffer extents_buf(extents); + sycl::buffer in_step_buf(in_step); + + auto e1 = q->submit([&](sycl::handler& h) { + sycl::accessor a_index(index_buf, h, sycl::read_write); + sycl::accessor a_temp(temp_buf, h, sycl::write_only, sycl::no_init); + sycl::accessor a_offsets(offsets_buf, h, sycl::read_only); + sycl::accessor a_extents(extents_buf, h, sycl::read_only); + sycl::accessor a_in_step(in_step_buf, h, sycl::read_only); + h.single_task([numel, + a_index, + a_offsets, + a_extents, + a_in_step, + slice_dims_back = slice_dims.back(), + a_temp, + out, + index_size = index.size()]() { + for (auto i = 0; i < numel; i += slice_dims_back) { + auto wyn = phi::vec_product(&a_index[0], &a_in_step[0], index_size); + a_temp[i / slice_dims_back * 2] = i; + a_temp[i / slice_dims_back * 2 + 1] = wyn; + + a_index[index_size - 2]++; + for (auto j = index_size - 2; j > 0; --j) { + if (a_index[j] >= a_offsets[j] + a_extents[j]) { + a_index[j] = a_offsets[j]; + a_index[j - 1] += 1; + } else { + break; + } + } + } + }); + }); + } + + for (auto i = 0; i < numel / slice_dims.back() * 2; i += 2) { + q->submit([&](sycl::handler& h) { + 
h.memcpy(out_data + temp[i], + in_data + temp[i + 1], + sizeof(T) * slice_dims.back()); + }); + } + q->wait(); + + out->Resize(out_dims); +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(slice, + intel_gpu, + ALL_LAYOUT, + custom_kernel::SliceRawKernel, + int64_t, + float, + double) {} diff --git a/backends/intel_gpu/kernels/softmax_kernel.cc b/backends/intel_gpu/kernels/softmax_kernel.cc new file mode 100644 index 000000000..2dd722d84 --- /dev/null +++ b/backends/intel_gpu/kernels/softmax_kernel.cc @@ -0,0 +1,228 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" +namespace custom_kernel { + +template +T ValueClip(const T& x) { + const T kThreshold = static_cast(-64.); + return x < kThreshold ? kThreshold : x; +} + +template +void Softmax(int axis_dim, const T* in, T* out, size_t M, size_t N) { + auto remain = N / axis_dim; + + for (size_t i = 0; i < M; ++i) { + for (size_t k = 0; k < remain; ++k) { + T max_val = in[i * N + k]; + for (size_t j = 0; j < axis_dim; ++j) { + max_val = std::max(max_val, in[i * N + j * remain + k]); + } + + auto exps = new T[axis_dim]; + for (size_t j = 0; j < axis_dim; ++j) { + exps[j] = std::exp(ValueClip(in[i * N + j * remain + k] - max_val)); + } + + T sum = 0; + for (size_t j = 0; j < axis_dim; ++j) { + sum += exps[j]; + } + + for (size_t j = 0; j < axis_dim; ++j) { + out[i * N + j * remain + k] = exps[j] / sum; + } + delete[] exps; + } + } +} + +template +void SoftmaxGrad( + const T* out, const T* out_grad, int axis_dim, int M, int N, T* x_grad) { + int num_remain = N / axis_dim; + T* dot = new T[M * num_remain]; + for (auto i = 0; i < M; ++i) { + for (auto k = 0; k < num_remain; ++k) { + dot[i * num_remain + k] = 0; + for (auto j = 0; j < axis_dim; ++j) { + dot[i * num_remain + k] += out[i * N + j * num_remain + k] * + out_grad[i * N + j * num_remain + k]; + } + } + } + for (auto i = 0; i < M; ++i) { + for (auto j = 0; j < axis_dim; ++j) { + for (auto k = 0; k < num_remain; ++k) { + x_grad[i * N + j * num_remain + k] = + (out_grad[i * N + j * num_remain + k] - dot[i * num_remain + k]) * + out[i * N + j * num_remain + k]; + } + } + } + delete[] dot; +} + +std::shared_ptr softmax_pd = nullptr; + +template +void SoftmaxGradKernel(const phi::Context& dev_ctx, + const phi::DenseTensor& out, + const phi::DenseTensor& out_grad, + int axis, + phi::DenseTensor* x_grad) { + show_kernel("SoftmaxGradKernel()"); + const int rank = x_grad->dims().size(); + const int calc_axis = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = x_grad->dims()[calc_axis]; + + dev_ctx.template Alloc(x_grad); + if (x_grad->numel() == 0) { + return; + } + + auto* q = static_cast(const_cast(dev_ctx.stream())); + + auto eng = dnnl::sycl_interop::make_engine(q->get_device(), q->get_context()); + auto engine_stream = dnnl::sycl_interop::make_stream(eng, *q); + + dnnl::memory::dims out_dims = out.dims(); 
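+  // Note: the backward primitive descriptor below is built with the forward
+  // primitive_desc cached in the file-scope `softmax_pd` by SoftmaxKernel,
+  // so the forward pass for this dtype must have run before the gradient.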
+ + std::vector logical_axis(out_dims.size(), 0); + for (auto i = 0; i < logical_axis.size(); ++i) { + logical_axis[i] = i; + } + + auto strides = dnn_support::computeStrides(out_dims, logical_axis); + + auto md_out = + dnnl::memory::desc(out_dims, dnn_support::toDnnType::type, strides); + + auto md_out_grad = + dnnl::memory::desc(out_dims, dnn_support::toDnnType::type, strides); + + auto md_x_grad = + dnnl::memory::desc(out_dims, dnn_support::toDnnType::type, strides); + + auto dst_memory_p = dnnl::memory(md_out, eng, out.data()); + auto diff_dst_memory_p = dnnl::memory(md_out_grad, eng, out_grad.data()); + auto diff_src_memory_p = dnnl::memory(md_x_grad, eng, x_grad->data()); + + auto bwd_desc = dnnl::softmax_backward::desc(md_out_grad, md_out, calc_axis); + auto bwd_pd_ = + dnnl::softmax_backward::primitive_desc(bwd_desc, eng, *softmax_pd); + + auto softmax_bwd_p = dnnl::softmax_backward(bwd_pd_); + + std::unordered_map softmax_args; + softmax_args.insert({DNNL_ARG_DST, dst_memory_p}); + softmax_args.insert({DNNL_ARG_DIFF_DST, diff_dst_memory_p}); + softmax_args.insert({DNNL_ARG_DIFF_SRC, diff_src_memory_p}); + + softmax_bwd_p.execute(engine_stream, softmax_args); + engine_stream.wait(); +} + +template +void SoftmaxKernel(const phi::Context& ctx, + const phi::DenseTensor& x, + int axis, + phi::DenseTensor* out) { + if constexpr (std::is_same::value) { + const int rank = x.dims().size(); + const int calc_axis = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = x.dims()[calc_axis]; + show_kernel("SoftmaxKernelOneDNN() rank=" + << rank << " calc_axis=" << calc_axis << " axis_dim=" + << axis_dim << " type=" << dnn_support::type2String::name()); + + const int n = phi::funcs::SizeToAxis(calc_axis, x.dims()); + const int d = phi::funcs::SizeFromAxis(calc_axis, x.dims()); + + auto x_data = x.data(); + auto out_data = ctx.template Alloc(out); + + dnnl::memory::dims dims_src = x.dims(); + dnnl::memory::dims dims_dst = out->dims(); + + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + auto* q = static_cast(const_cast(ctx.stream())); + + auto eng = + dnnl::sycl_interop::make_engine(q->get_device(), q->get_context()); + auto engine_stream = dnnl::sycl_interop::make_stream(eng, *q); + + std::vector logical_axis(dims_src.size(), 0); + for (auto i = 0; i < logical_axis.size(); ++i) { + logical_axis[i] = i; + } + + auto strides = dnn_support::computeStrides(dims_src, logical_axis); + + auto md_src = + dnnl::memory::desc(dims_src, dnn_support::toDnnType::type, strides); + + auto md_dst = + dnnl::memory::desc(dims_src, dnn_support::toDnnType::type, strides); + + show_debug("ComputeStrides = " << strides); + + auto mem_src = dnnl::memory(md_src, eng, x_data); + auto mem_dst = dnnl::memory(md_dst, eng, out_data); + + auto softmax_d = dnnl::softmax_forward::desc( + dnnl::prop_kind::forward_training, md_src, calc_axis); + + softmax_pd = + std::make_shared(softmax_d, eng); + + auto softmax_prim = dnnl::softmax_forward(*softmax_pd); + std::unordered_map softmax_args; + softmax_args.insert({DNNL_ARG_SRC, mem_src}); + softmax_args.insert({DNNL_ARG_DST, mem_dst}); + + // // Primitive execution. + softmax_prim.execute(engine_stream, softmax_args); + // Wait for the computation to finalize. 
+ engine_stream.wait(); + + } else { + std::stringstream ss; + ss << "SoftMax doesn't support type=" + << dnn_support::type2String::name(); + + show_error(ss.str()); + throw std::runtime_error(ss.str()); + } +} + +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(softmax, + intel_gpu, + ALL_LAYOUT, + custom_kernel::SoftmaxKernel, + float, + double) {} + +PD_BUILD_PHI_KERNEL(softmax_grad, + intel_gpu, + ALL_LAYOUT, + custom_kernel::SoftmaxGradKernel, + float) {} diff --git a/backends/intel_gpu/kernels/transpose_kernel.cc b/backends/intel_gpu/kernels/transpose_kernel.cc new file mode 100644 index 000000000..359cbe0d0 --- /dev/null +++ b/backends/intel_gpu/kernels/transpose_kernel.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "kernels/dnn_support.hpp" +#include "kernels/phi_funcs.h" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +void TransposeKernelGPU(const phi::Context& ctx, + const phi::DenseTensor& x, + const std::vector& axis, + phi::DenseTensor* out) { + show_kernel("TransposeKernelGPU "); + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + auto* q = static_cast(const_cast(ctx.stream())); + + auto eng = dnnl::sycl_interop::make_engine(q->get_device(), q->get_context()); + auto engine_stream = dnnl::sycl_interop::make_stream(eng, *q); + + auto x_dims = x.dims(); + auto out_dims = out->dims(); + + auto x_data = x.data(); + auto out_data = ctx.template Alloc(out); + show_debug("x{dims}=" << x.dims() << " x{rank}=" << x_dims.size() + << " out{dims}=" << out->dims() << " axis=" << axis); + if (out->numel() == 0) { + return; + } + auto rank = x_dims.size(); + + if (rank == 1) { + auto total_cpy_bytes = x.numel() * sizeof(T); + q->submit( + [&](sycl::handler& h) { h.memcpy(out_data, x_data, total_cpy_bytes); }); + q->wait(); + return; + } + + PD_CHECK(axis.size() == rank, + "axis.size (%d) must be equal the rank of input (%d).", + axis.size(), + rank); + + dnnl::memory::dims dims_src = x.dims(); + dnnl::memory::dims dims_dst = out->dims(); + std::vector logical_axis(dims_src.size(), 0); + for (auto i = 0; i < logical_axis.size(); ++i) { + logical_axis[i] = i; + } + show_debug("logical_axis=" << logical_axis << " axis=" << axis); + auto md_src = + dnnl::memory::desc(dims_src, + dnn_support::toDnnType::type, + dnn_support::computeStrides(dims_src, logical_axis)); + + auto md_dst = dnnl::memory::desc(dims_src, + dnn_support::toDnnType::type, + dnn_support::computeStrides(dims_src, axis)); + + auto mem_src = dnnl::memory(md_src, eng, x_data); + auto mem_dst = dnnl::memory(md_dst, eng, out_data); + + auto reorder_pd = dnnl::reorder::primitive_desc(eng, md_src, eng, md_dst); + + // Create the primitive. + auto reorder_prim = dnnl::reorder(reorder_pd); + + std::unordered_map reorder_args; + reorder_args.insert({DNNL_ARG_SRC, mem_src}); + reorder_args.insert({DNNL_ARG_DST, mem_dst}); + + // Primitive execution: reorder with scaled sum. 
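+  // Despite the boilerplate comment above, no scaling or accumulation is
+  // applied here: md_dst keeps the source dims but uses strides permuted by
+  // `axis`, so the reorder simply writes the data out in transposed layout.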
+ reorder_prim.execute(engine_stream, reorder_args); + engine_stream.wait(); +} +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(transpose, + intel_gpu, + ALL_LAYOUT, + custom_kernel::TransposeKernelGPU, + float) {} diff --git a/backends/intel_gpu/kernels/uniform_random_kernel.cc b/backends/intel_gpu/kernels/uniform_random_kernel.cc new file mode 100644 index 000000000..54bedfac9 --- /dev/null +++ b/backends/intel_gpu/kernels/uniform_random_kernel.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "kernels/dnn_support.hpp" +#include "paddle/phi/capi/all.h" + +namespace custom_kernel { + +template +inline void UniformRealDistribution(T *data, + const int64_t &size, + const float &min, + const float &max, + std::shared_ptr engine) { + std::uniform_real_distribution dist(static_cast(min), + static_cast(max)); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(*engine); + } +} + +template +void UniformRandomRawKernel(const phi::Context &dev_ctx, + const phi::IntArray &shape, + phi::DataType dtype, + const phi::Scalar &min, + const phi::Scalar &max, + int seed, + int diag_num, + int diag_step, + float diag_val, + phi::DenseTensor *out) { + show_kernel( + "UniformRandom-SYCL type=" << dnn_support::type2String::name()); + + auto shape_data = shape.GetData(); + out->Resize(std::vector(shape_data.begin(), shape_data.end())); + auto out_data = dev_ctx.template Alloc(out); + auto numel = out->numel(); + + // // 1. CPU implement + phi::DenseTensor cpu_out; + cpu_out.Resize(std::vector(shape_data.begin(), shape_data.end())); + cpu_out.set_dtype(out->dtype()); + auto cpu_data = dev_ctx.template HostAlloc(&cpu_out); + + std::shared_ptr engine; + engine = std::make_shared(); + engine->seed(seed); + + UniformRealDistribution( + cpu_data, numel, min.to(), max.to(), engine); + if (diag_num > 0) { + PD_CHECK(numel, + (diag_num - 1) * (diag_step + 1), + "ShapeInvalid: the diagonal's elements is equal (num-1) " + "* (step-1) with num %d, step %d," + "It should be smaller than %d, but received %d", + diag_num, + diag_step, + (diag_num - 1) * (diag_step + 1), + numel); + for (int64_t i = 0; i < diag_num; ++i) { + int64_t pos = i * diag_step + i; + cpu_data[pos] = diag_val; + } + } + + // 2. 
CPU Copy to IntelGPU + auto *q = static_cast(dev_ctx.stream()); + q->memcpy(out_data, cpu_data, numel * sizeof(T)); +} + +template +void UniformRandomKernel(const phi::Context &dev_ctx, + const phi::IntArray &shape, + phi::DataType dtype, + // float min, + // float max, + const phi::Scalar &min, + const phi::Scalar &max, + int seed, + phi::DenseTensor *out) { + show_kernel( + "UniformRandom-SYCL type=" << dnn_support::type2String::name()); + custom_kernel::UniformRandomRawKernel( + dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out); +} +} // namespace custom_kernel + +PD_BUILD_PHI_KERNEL(uniform_random_raw, + intel_gpu, + ALL_LAYOUT, + custom_kernel::UniformRandomRawKernel, + float) {} + +PD_BUILD_PHI_KERNEL(uniform_random, + intel_gpu, + ALL_LAYOUT, + custom_kernel::UniformRandomKernel, + float) {} diff --git a/backends/intel_gpu/load.sh b/backends/intel_gpu/load.sh index 6424fea25..1988332de 100644 --- a/backends/intel_gpu/load.sh +++ b/backends/intel_gpu/load.sh @@ -20,7 +20,7 @@ echo $d export PYTHONPATH=$PYTHONPATH:${PaddleDev}/python/tests/ -comp="dnnl tbb compiler dpl" +comp="dnnl tbb compiler" for item in $comp; do diff --git a/backends/intel_gpu/runtime/runtime.cc b/backends/intel_gpu/runtime/runtime.cc index 9cf39a595..47719d7f1 100644 --- a/backends/intel_gpu/runtime/runtime.cc +++ b/backends/intel_gpu/runtime/runtime.cc @@ -11,14 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #include #include -#include #include #include -#include "./dnn_support.hpp" +#include "kernels/dnn_support.hpp" #include "paddle/phi/backends/device_ext.h" #define MEMORY_FRACTION 0.5f @@ -41,7 +41,7 @@ DeviceConfigPtr devconf; std::mutex mx; std::recursive_mutex rmux; -auto intel_match = [](const sycl::device &dev) -> bool { +auto intel_match = [](sycl::device &dev) -> bool { // NOLINT const auto name = dev.template get_info(); return (name.find("Intel(R) Graphics") != std::string::npos) ? 
true : false; }; @@ -52,7 +52,7 @@ struct DeviceCtx { bool _def_stream; size_t allocated_mem; size_t _dev_memory_size; - explicit DeviceCtx(sycl::device dev) + DeviceCtx(sycl::device dev) // NOLINT : _dev{std::move(dev)}, _def_stream{true}, allocated_mem{0}, @@ -80,7 +80,10 @@ struct DeviceCtx { return *(_streams[index]); } - void copy(const sycl::queue &q, void *dst, const void *src, size_t size) { + void copy(sycl::queue &q, // NOLINT + void *dst, + const void *src, + size_t size) { q.submit([&](sycl::handler &h) { h.memcpy(dst, src, size); }); q.wait(); } @@ -217,9 +220,10 @@ C_Status Allocate(const C_Device device, void **ptr, size_t size) { show_memory("request allocate size=" << size << " device=" << device->id); if (size > reg_dev[device->id].getFreeMemorySize()) { - show_error("## No free memory INTERNAL ERROR OUT OF MEMORY requested size=" - << size << " left=" << reg_dev[device->id].getFreeMemorySize() - << " ##"); + show_error( + "#### No free memory INTERNAL ERROR OUT OF MEMORY requested size=" + << size << " left=" << reg_dev[device->id].getFreeMemorySize() + << " #####"); return C_FAILED; } diff --git a/backends/intel_gpu/tests/CMakeLists.txt b/backends/intel_gpu/tests/CMakeLists.txt index f488d9153..0f98754eb 100644 --- a/backends/intel_gpu/tests/CMakeLists.txt +++ b/backends/intel_gpu/tests/CMakeLists.txt @@ -26,20 +26,14 @@ function(py_test_modules TARGET_NAME) CUSTOM_DEVICE_ROOT=${CMAKE_BINARY_DIR}/python/paddle-plugins/ PYTHONPATH=${PYTHON_SOURCE_DIR}:${PYTHON_SOURCE_DIR}/python/paddle/fluid/tests/unittests:$ENV{PYTHONPATH} ${py_test_modules_ENVS} - # python ${CMAKE_CURRENT_BINARY_DIR}/${py_test_modules_MODULES}.py + # python ${PYTHON_SOURCE_DIR}/tools/test_runner.py + # ${py_test_modules_MODULES} + python ${CMAKE_CURRENT_BINARY_DIR}/${py_test_modules_MODULES}.py WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) - message(STATUS "PYTHONPATH : $ENV{PYTHONPATH}") + if(py_test_modules_SERIAL) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) endif() endfunction() -file( - GLOB TEST_OPS - RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "test_*.py") -string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") - -foreach(TEST_OP ${TEST_OPS}) - py_test_modules(${TEST_OP} MODULES ${TEST_OP}) -endforeach() +add_subdirectory(unittests) diff --git a/backends/intel_gpu/tests/test_MNIST_model.py b/backends/intel_gpu/tests/test_MNIST_model.py new file mode 100644 index 000000000..003bed0b2 --- /dev/null +++ b/backends/intel_gpu/tests/test_MNIST_model.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import paddle +from paddle.optimizer import SGD + +paddle.set_device("intel_gpu") + +BATCH_SIZE = 64 + + +class MnistDataset(paddle.vision.datasets.MNIST): + def __init__(self, mode, return_label=True): + super(MnistDataset, self).__init__(mode=mode) + self.return_label = return_label + + def __getitem__(self, idx): + img = np.reshape(self.images[idx], [1, 28, 28]) + img = img / 255.0 * 2.0 - 1.0 + if self.return_label: + return img, np.array(self.labels[idx]).astype("int") + return (img,) + + def __len__(self): + return len(self.images) + + +train_reader = paddle.io.DataLoader( + MnistDataset(mode="train"), batch_size=BATCH_SIZE, drop_last=True +) +test_reader = paddle.io.DataLoader( + MnistDataset(mode="test"), batch_size=BATCH_SIZE, drop_last=True +) + + +class MNIST(paddle.nn.Layer): + def __init__(self): + super(MNIST, self).__init__() + self.shape = 1 * 28 * 28 + self.size = 10 + self.output_weight = self.create_parameter([self.shape, self.size]) + self.accuracy = paddle.metric.Accuracy() + + def forward(self, inputs, label=None): + x = paddle.reshape(inputs, shape=[-1, self.shape]) + x = paddle.matmul(x, self.output_weight) + x = paddle.nn.functional.softmax(x) + if label is not None: + self.accuracy.reset() + correct = self.accuracy.compute(x, label) + self.accuracy.update(correct) + acc = self.accuracy.accumulate() + return x, acc + else: + return x + + +mnist = MNIST() +sgd = SGD(learning_rate=0.01, parameters=mnist.parameters()) + +epoch_num = 5 +for epoch in range(epoch_num): + for batch_id, data in enumerate(train_reader()): + img = data[0] + label = data[1] + pred, acc = mnist(img, label) + avg_loss = paddle.nn.functional.cross_entropy(pred, label) + avg_loss.backward() + sgd.step() + sgd.clear_grad() + + if batch_id % 1 == 0: + print( + "Epoch {} step {}, Loss = {:}, Accuracy = {:}".format( + epoch, batch_id, avg_loss.numpy(), acc + ) + ) +model_dict = mnist.state_dict() +paddle.save(model_dict, "mnist.pdparams") diff --git a/backends/intel_gpu/tests/unittests/CMakeLists.txt b/backends/intel_gpu/tests/unittests/CMakeLists.txt new file mode 100644 index 000000000..b20535026 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License + +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach() diff --git a/backends/intel_gpu/tests/unittests/test_argsort_op.py b/backends/intel_gpu/tests/unittests/test_argsort_op.py new file mode 100644 index 000000000..70dfeb3e1 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_argsort_op.py @@ -0,0 +1,439 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import numpy as np +import six +import paddle.fluid.core as core + +from paddle.fluid.framework import Program, grad_var_name +from paddle.fluid.executor import Executor + +paddle.enable_static() + +np.random.seed(123) + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +class PyArgsort(object): + def __init__(self, input_shape, axis, descending, dtype): + self.x = np.random.random(input_shape).astype(dtype) + self.label = np.random.random(input_shape).astype(dtype) + if axis < 0: + self.axis = axis + len(self.x.shape) + else: + self.axis = axis + self.descending = descending + + def forward(self): + if self.descending: + self.indices = np.flip( + np.argsort(self.x, kind="quicksort", axis=self.axis), self.axis + ) + self.sorted_x = np.flip( + np.sort(self.x, kind="quicksort", axis=self.axis), self.axis + ) + else: + self.indices = np.argsort(self.x, kind="quicksort", axis=self.axis) + self.sorted_x = np.sort(self.x, kind="quicksort", axis=self.axis) + self.loss = self.sorted_x * self.label + self.loss = np.sum(self.loss) + out = ( + np.array(self.indices, dtype=self.indices.dtype), + np.array(self.sorted_x, dtype=self.sorted_x.dtype), + np.array([self.loss], dtype=self.loss.dtype), + ) + return out + + +def create_tensor(np_data, place): + tensor = core.LoDTensor() + tensor.set(np_data, place) + return tensor + + +class TestArgsortOpCPU(unittest.TestCase): + def setup_program(self): + self.main_program = Program() + self.startup_program = Program() + self.init_place() + + def setUp(self): + self.init_axis() + self.init_datatype() + self.init_direction() + self.init_inputshape() + + self.setup_program() + self.feed_data_field = {"x", "label"} + self.grad_data_field = {"x"} + + self.py_argsort = PyArgsort( + self.input_shape, self.axis, self.descending, self.dtype + ) + + with fluid.program_guard(self.main_program, self.startup_program): + x = fluid.layers.data(name="x", shape=self.input_shape, dtype=self.dtype) + x.stop_gradient = False + label = fluid.layers.data( + name="label", shape=self.input_shape, dtype=self.dtype + ) + self.sorted_x, self.index = fluid.layers.argsort( + input=x, axis=self.axis, descending=self.descending + ) + self.sorted_x.stop_gradient = False + loss = fluid.layers.elementwise_mul(self.sorted_x, label) + self.loss = fluid.layers.reduce_sum(loss) + + def forward(self): + self.feed_map = { + x: create_tensor(getattr(self.py_argsort, x), self.place) + for x in self.feed_data_field + } + exe = Executor(self.place) + out = exe.run( + self.main_program, + feed=self.feed_map, + fetch_list=[self.index, self.sorted_x, self.loss], + ) + return out + + def backward(self): + self.feed_map = { + x: create_tensor(getattr(self.py_argsort, x), self.place) + for x in self.feed_data_field + } + fetch_list = [ + self.main_program.global_block().var(grad_var_name(x)) + for x in self.grad_data_field + ] + exe = Executor(self.place) + out = exe.run( + self.main_program, + 
feed=self.feed_map, + fetch_list=fetch_list, + return_numpy=False, + ) + return out + + # def test_backward(self, numeric_grad_delta=1e-5, max_relative_error=1e-7): + # self.check_forward() + + # with fluid.program_guard(self.main_program, self.startup_program): + # append_backward(self.loss) + + # ana_grad = [np.array(x) for x in self.backward()] + + # num_grad = self.get_numerical_gradient(delta=numeric_grad_delta) + # self.assert_is_close( + # num_grad, + # ana_grad, + # 'x', + # max_relative_error=max_relative_error, + # msg_prefix="Gradient Check On %s" % str(self.place)) + + def check_forward(self): + pd_outputs = self.forward() + py_outputs = self.py_argsort.forward() + for pd_output, py_output in zip(pd_outputs, py_outputs): + self.assertEqual(pd_output.shape, py_output.shape) + self.assertTrue(np.allclose(pd_output, py_output, atol=0, equal_nan=False)) + + def get_numerical_gradient(self, delta=1e-7): + if self.dtype == "float16": + delta = np.array(delta).astype(np.float16) + feed_list = [getattr(self.py_argsort, x) for x in self.grad_data_field] + grad_list = [np.zeros_like(x) for x in feed_list] + for feed, grad in zip(feed_list, grad_list): + for f, g in np.nditer([feed, grad], op_flags=["readwrite"]): + o = float(f) + f[...] = o + delta + y_pos = self.forward()[2] + + f[...] = o - delta + y_neg = self.forward()[2] + + f[...] = o + dout_dfeed = (y_pos - y_neg) / (delta * 2) + g[...] = dout_dfeed[0] + + return grad_list + + def assert_is_close( + self, numeric_grads, analytic_grads, names, max_relative_error, msg_prefix + ): + for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names): + abs_a = np.abs(a) + abs_a[abs_a < 1e-3] = 1 + + diff_mat = np.abs(a - b) / abs_a + max_diff = np.max(diff_mat) + + def err_msg(): + offset = np.argmax(diff_mat > max_relative_error) + return ( + "%s error, %s variable %s max gradient diff %f over limit %f, " + "the first error element is %d, expected %f, but got %f." 
+ ) % ( + "argsort", + msg_prefix, + name, + max_diff, + max_relative_error, + offset, + a.flatten()[offset], + b.flatten()[offset], + ) + + self.assertLessEqual(max_diff, max_relative_error, err_msg()) + + def init_axis(self): + self.axis = -1 + + def init_datatype(self): + self.dtype = "float32" + + def init_direction(self): + self.descending = False + + def init_inputshape(self): + self.input_shape = (2, 2, 2, 2, 3) + + def init_place(self): + self.place = core.CustomPlace("intel_gpu", 0) + + +class TestArgsortOpAxis0CPU(TestArgsortOpCPU): + def init_axis(self): + self.axis = 0 + + +class TestArgsortOpAxis1CPU(TestArgsortOpCPU): + def init_axis(self): + self.axis = 1 + + +class TestArgsortOpAxis2CPU(TestArgsortOpCPU): + def init_axis(self): + self.axis = 2 + + +class TestArgsortOpAxisNeg1CPU(TestArgsortOpCPU): + def init_axis(self): + self.axis = -1 + + +class TestArgsortOpAxisNeg2CPU(TestArgsortOpCPU): + def init_axis(self): + self.axis = -2 + + +class TestArgsortOpDescendingAxisCPU(TestArgsortOpCPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis0CPU(TestArgsortOpAxis0CPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis1CPU(TestArgsortOpAxis1CPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxis2CPU(TestArgsortOpAxis2CPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg1CPU(TestArgsortOpAxisNeg1CPU): + def init_direction(self): + self.descending = True + + +class TestArgsortOpDescendingAxisNeg2CPU(TestArgsortOpAxisNeg2CPU): + def init_direction(self): + self.descending = True + + +class TestArgsortErrorOnCPU(unittest.TestCase): + def setUp(self): + self.place = core.CustomPlace("intel_gpu", 0) + + def test_error(self): + def test_fluid_var_type(): + with fluid.program_guard(fluid.Program()): + x = [1] + output = fluid.layers.argsort(input=x) + self.assertRaises(TypeError, test_fluid_var_type) + + def test_paddle_var_type(): + with fluid.program_guard(fluid.Program()): + x = [1] + output = paddle.argsort(input=x) + self.assertRaises(TypeError, test_paddle_var_type) + + +class TestArgsort(unittest.TestCase): + def init(self): + self.input_shape = [ + 10, + ] + self.axis = 0 + + def setUp(self): + self.init() + self.place = core.CustomPlace("intel_gpu", 0) + self.data = np.random.rand(*self.input_shape).astype("float32") + + def test_api(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program()): + input = fluid.data(name="input", shape=self.input_shape, dtype="float32") + + output = paddle.argsort(input, axis=self.axis) + output2 = paddle.argsort(input, axis=self.axis, descending=True) + + exe = fluid.Executor(self.place) + result, result2 = exe.run( + feed={"input": self.data}, fetch_list=[output, output2] + ) # 始终有一半会被擦掉 + + np_result = np.argsort(self.data, axis=self.axis) + self.assertEqual((result == np_result).all(), True) + + np_result2 = np.argsort(-self.data, axis=self.axis).astype("int64") + self.assertEqual((result2 == np_result2).all(), True) + + +class TestArgsort2(TestArgsort): + def init(self): + self.input_shape = [10, 1] + self.axis = 0 + + +class TestArgsort3(TestArgsort): + def init(self): + self.input_shape = [1, 10] + self.axis = 1 + + +# TODO(Zhiwei35): support argsort of dims > 3 +# class TestArgsort4(TestArgsort): +# def init(self): +# self.input_shape = [2, 3, 4] +# self.axis = 1 + + +class TestArgsortImperative(unittest.TestCase): + def init(self): + self.input_shape 
= [ + 100, + ] + self.axis = 0 + + def setUp(self): + self.init() + self.input_data = np.random.rand(*self.input_shape) + self.place = core.CustomPlace("intel_gpu", 0) + + def test_api(self): + paddle.disable_static(self.place) + var_x = paddle.to_tensor(self.input_data) + out = paddle.argsort(var_x, axis=self.axis) + expect = np.argsort(self.input_data, axis=self.axis) + self.assertEqual((expect == out.numpy()).all(), True) + out2 = paddle.argsort(var_x, axis=self.axis, descending=True) + expect2 = np.argsort(-self.input_data, axis=self.axis) + self.assertEqual((expect2 == out2.numpy()).all(), True) + paddle.enable_static() + + +class TestArgsortImperative2(TestArgsortImperative): + def init(self): + self.input_shape = [100, 1] + self.axis = 0 + + +class TestArgsortImperative3(TestArgsortImperative): + def init(self): + self.input_shape = [1, 10000] + self.axis = 1 + + +class TestArgsortImperative4(TestArgsortImperative): + def init(self): + self.input_shape = [17, 60] + self.axis = 1 + + +class TestArgsortImperative5(TestArgsortImperative): + def init(self): + self.input_shape = [60, 70] + self.axis = 1 + + +class TestArgsortImperative6(TestArgsortImperative): + def init(self): + self.input_shape = [10, 70] + self.axis = 1 + + +class TestArgsortLeNetCase(TestArgsortImperative): + def init(self): + self.input_shape = [64, 10] + self.axis = 1 + + +# TODO(Zhiwei35): support argsort of dims > 3 +# class TestArgsortImperative4(TestArgsortImperative): +# def init(self): +# self.input_shape = [2, 3, 4] +# self.axis = 1 + +# TODO(Zhiwei35): support argsort with NaN +# class TestArgsortWithInputNaN(unittest.TestCase): +# def init(self): +# self.axis = 0 + +# def setUp(self): +# self.init() +# self.input_data = np.array([1.0, np.nan, 3.0, 2.0]) +# self.place = core.CustomPlace("intel_gpu", 0) + +# def test_api(self): +# paddle.disable_static(self.place) +# var_x = paddle.to_tensor(self.input_data) +# out = paddle.argsort(var_x, axis=self.axis) +# self.assertEqual((out.numpy() == np.array([0, 3, 2, 1])).all(), True) +# out = paddle.argsort(var_x, axis=self.axis, descending=True) +# self.assertEqual((out.numpy() == np.array([1, 2, 3, 0])).all(), True) +# paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_assign_value_op.py b/backends/intel_gpu/tests/unittests/test_assign_value_op.py new file mode 100644 index 000000000..8dc057f30 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_assign_value_op.py @@ -0,0 +1,110 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np + +import op_test +import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework +import paddle.fluid.layers as layers +from op_test import OpTest + +paddle.enable_static() + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +class TestAssignValueOp(op_test.OpTest): + def setUp(self): + self.op_type = "assign_value" + self.inputs = {} + self.attrs = {} + self.init_data() + self.attrs["shape"] = self.value.shape + self.attrs["dtype"] = framework.convert_np_dtype_to_dtype_(self.value.dtype) + self.outputs = {"Out": self.value} + + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.float32) + self.attrs["fp32_values"] = [float(v) for v in self.value.flat] + + def test_forward(self): + self.check_output() + + +class TestAssignValueOp2(TestAssignValueOp): + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.int32) + self.attrs["int32_values"] = [int(v) for v in self.value.flat] + + +class TestAssignValueOp3(TestAssignValueOp): + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.int64) + self.attrs["int64_values"] = [int(v) for v in self.value.flat] + + +class TestAssignValueOp4(TestAssignValueOp): + def init_data(self): + self.value = np.random.choice(a=[False, True], size=(2, 5)).astype(bool) + self.attrs["bool_values"] = [int(v) for v in self.value.flat] + + +class TestAssignApi(unittest.TestCase): + def setUp(self): + self.init_dtype() + self.value = (-100 + 200 * np.random.random(size=(2, 5))).astype(self.dtype) + self.place = paddle.CustomPlace("intel_gpu", 0) + + def init_dtype(self): + self.dtype = "float32" + + def test_assign(self): + main_program = fluid.Program() + with fluid.program_guard(main_program): + x = layers.create_tensor(dtype=self.dtype) + layers.assign(input=self.value, output=x) + + exe = fluid.Executor(self.place) + [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x]) + np.testing.assert_array_equal(fetched_x, self.value) + self.assertEqual(fetched_x.dtype, self.value.dtype) + + +class TestAssignApi2(TestAssignApi): + def init_dtype(self): + self.dtype = "int32" + + +class TestAssignApi3(TestAssignApi): + def init_dtype(self): + self.dtype = "int64" + + +class TestAssignApi4(TestAssignApi): + def init_dtype(self): + self.dtype = bool + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_cast_op.py b/backends/intel_gpu/tests/unittests/test_cast_op.py new file mode 100644 index 000000000..7a3802ede --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_cast_op.py @@ -0,0 +1,132 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from op_test import OpTest, convert_uint16_to_float + +paddle.enable_static() + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +class TestCastOpFp32ToFp64(OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {"X": ipt.astype("float32")} + self.outputs = {"Out": ipt.astype("float64")} + self.attrs = { + "in_dtype": int(core.VarDesc.VarType.FP32), + "out_dtype": int(core.VarDesc.VarType.FP64), + } + self.op_type = "cast" + + def test_check_output(self): + self.check_output() + + def test_grad(self): + self.check_grad(["X"], ["Out"]) + + +class TestCastOpFp16ToFp32(OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {"X": ipt.astype("float16")} + self.outputs = {"Out": ipt.astype("float32")} + self.attrs = { + "in_dtype": int(core.VarDesc.VarType.FP16), + "out_dtype": int(core.VarDesc.VarType.FP32), + } + self.op_type = "cast" + self.__class__.no_need_check_grad = True + + def test_check_output(self): + self.check_output(atol=1e-3) + + +# TODO(Zhiwei35): the case's expected dispatch to phi kernels, but fluid kernels +# class TestCastOpFp32ToFp16(OpTest): +# def setUp(self): +# ipt = np.random.random(size=[10, 10]) +# self.inputs = {'X': ipt.astype('float32')} +# self.outputs = {'Out': ipt.astype('float16')} +# self.attrs = { +# 'in_dtype': int(core.VarDesc.VarType.FP32), +# 'out_dtype': int(core.VarDesc.VarType.FP16) +# } +# self.op_type = 'cast' +# self.__class__.no_need_check_grad = True + +# def test_check_output(self): +# self.check_output(atol=1e-3) + + +class TestCastOpBf16ToFp32(OpTest): + def setUp(self): + ipt = np.array(np.random.randint(10, size=[10, 10])).astype("uint16") + self.inputs = {"X": ipt} + self.outputs = {"Out": convert_uint16_to_float(ipt)} + self.attrs = { + "in_dtype": int(core.VarDesc.VarType.BF16), + "out_dtype": int(core.VarDesc.VarType.FP32), + } + self.op_type = "cast" + self.__class__.no_need_check_grad = True + + def test_check_output(self): + self.check_output() + + +# TODO(Zhiwei35): the case's expected dispatch to phi kernels, but fluid kernels +# class TestCastOpFp32ToBf16(OpTest): +# def setUp(self): +# ipt = np.random.random(size=[10, 10]).astype('float32') +# self.inputs = {'X': ipt} +# self.outputs = {'Out': convert_float_to_uint16(ipt)} +# self.attrs = { +# 'in_dtype': int(core.VarDesc.VarType.FP32), +# 'out_dtype': int(core.VarDesc.VarType.BF16) +# } +# self.op_type = 'cast' +# self.__class__.no_need_check_grad = True + +# def test_check_output(self): +# self.check_output() + + +class TestCastOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of cast_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CustomPlace("intel_gpu", 0) + ) + self.assertRaises(TypeError, fluid.layers.cast, x1, "int32") + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_compare_op.py b/backends/intel_gpu/tests/unittests/test_compare_op.py new file mode 100755 index 000000000..31866d167 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_compare_op.py @@ -0,0 +1,308 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import op_test +from op_test import OpTest +import unittest +import numpy +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +paddle.enable_static() +OpTest._get_places = get_places + + +def create_test_class(op_type, typename, callback): + class Cls(op_test.OpTest): + def setUp(self): + a = numpy.random.random(size=(10, 7)).astype(typename) + b = numpy.random.random(size=(10, 7)).astype(typename) + c = callback(a, b) + self.python_api = eval("paddle." + op_type) + self.inputs = {"X": a, "Y": b} + self.outputs = {"Out": c} + self.op_type = op_type + + def test_output(self): + self.check_output(check_eager=False) + + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.layers.data(name="x", shape=[2], dtype="int32") + y = fluid.layers.data(name="y", shape=[2], dtype="int32") + a = fluid.layers.data(name="a", shape=[2], dtype="int16") + if self.op_type == "less_than": + self.assertRaises( + TypeError, fluid.layers.less_than, x=x, y=y, force_cpu=1 + ) + op = eval("fluid.layers.%s" % self.op_type) + self.assertRaises(TypeError, op, x=x, y=y, cond=1) + self.assertRaises(TypeError, op, x=x, y=a) + self.assertRaises(TypeError, op, x=a, y=y) + + cls_name = "{0}_{1}".format(op_type, typename) + Cls.__name__ = cls_name + globals()[cls_name] = Cls + + +for _type_name in {"float32", "float64", "int32", "int64"}: + if _type_name == "float64" and core.is_compiled_with_rocm(): + _type_name = "float32" + + create_test_class("less_than", _type_name, lambda _a, _b: _a < _b) + create_test_class("less_equal", _type_name, lambda _a, _b: _a <= _b) + create_test_class("greater_than", _type_name, lambda _a, _b: _a > _b) + create_test_class("greater_equal", _type_name, lambda _a, _b: _a >= _b) + create_test_class("equal", _type_name, lambda _a, _b: _a == _b) + create_test_class("not_equal", _type_name, lambda _a, _b: _a != _b) + + +def create_paddle_case(op_type, callback): + class PaddleCls(unittest.TestCase): + def setUp(self): + self.op_type = op_type + self.input_x = np.array([1, 2, 3, 4]).astype(np.int64) + self.input_y = np.array([1, 3, 2, 4]).astype(np.int64) + self.real_result = callback(self.input_x, self.input_y) + self.place = fluid.CustomPlace("intel_gpu", 0) + if core.is_compiled_with_cuda(): + self.place = paddle.CustomPlace("intel_gpu", 0) + + def test_api(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.data(name="x", shape=[4], dtype="int64") + y = fluid.data(name="y", shape=[4], dtype="int64") + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = fluid.Executor(self.place) + (res,) = exe.run( + feed={"x": self.input_x, "y": self.input_y}, fetch_list=[out] + ) + self.assertEqual((res == self.real_result).all(), True) + + def 
test_api_float(self): + if self.op_type == "equal": + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.data(name="x", shape=[4], dtype="int64") + y = fluid.data(name="y", shape=[1], dtype="int64") + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = fluid.Executor(self.place) + (res,) = exe.run( + feed={"x": self.input_x, "y": 1.0}, fetch_list=[out] + ) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((res == self.real_result).all(), True) + + def test_dynamic_api(self): + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(self.input_x) + y = paddle.to_tensor(self.input_y) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_dynamic_api_int(self): + if self.op_type == "equal": + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, 1) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_dynamic_api_float(self): + if self.op_type == "equal": + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, 1.0) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_not_equal(self): + if self.op_type == "not_equal": + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(np.array([1.2e-8, 2, 2, 1]), dtype="float32") + y = paddle.to_tensor(np.array([1.1e-8, 2, 2, 1]), dtype="float32") + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_assert(self): + def test_dynamic_api_string(self): + if self.op_type == "equal": + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, "1.0") + paddle.enable_static() + + self.assertRaises(TypeError, test_dynamic_api_string) + + def test_dynamic_api_bool(self): + if self.op_type == "equal": + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, True) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_broadcast_api_1_float(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name="x", shape=[1, 2, 1, 3], dtype="float32") + y = paddle.static.data(name="y", shape=[1, 2, 3], dtype="float32") + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.float32) + input_y = np.arange(0, 6).reshape((1, 2, 3)).astype(np.float32) + real_result = callback(input_x, input_y) + (res,) = exe.run(feed={"x": input_x, "y": input_y}, fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_broadcast_api_2_float(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = 
paddle.static.data(name="x", shape=[1, 2, 3], dtype="float32") + y = paddle.static.data(name="y", shape=[1, 2, 1, 3], dtype="float32") + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 6).reshape((1, 2, 3)).astype(np.float32) + input_y = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.float32) + real_result = callback(input_x, input_y) + (res,) = exe.run(feed={"x": input_x, "y": input_y}, fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_broadcast_api_3_float(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name="x", shape=[5], dtype="float32") + y = paddle.static.data(name="y", shape=[3, 1], dtype="float32") + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 5).reshape((5)).astype(np.float32) + input_y = np.array([5, 3, 2]).reshape((3, 1)).astype(np.float32) + real_result = callback(input_x, input_y) + (res,) = exe.run(feed={"x": input_x, "y": input_y}, fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_attr_name(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.layers.data(name="x", shape=[4], dtype="int32") + y = fluid.layers.data(name="y", shape=[4], dtype="int32") + op = eval("paddle.%s" % (self.op_type)) + out = op(x=x, y=y, name="name_%s" % (self.op_type)) + self.assertEqual("name_%s" % (self.op_type) in out.name, True) + + cls_name = "TestCase_{}".format(op_type) + PaddleCls.__name__ = cls_name + globals()[cls_name] = PaddleCls + + +# TODO Broadcast in kernel other than float +create_paddle_case("less_than", lambda _a, _b: _a < _b) +create_paddle_case("less_equal", lambda _a, _b: _a <= _b) +create_paddle_case("greater_than", lambda _a, _b: _a > _b) +create_paddle_case("greater_equal", lambda _a, _b: _a >= _b) +create_paddle_case("equal", lambda _a, _b: _a == _b) +create_paddle_case("not_equal", lambda _a, _b: _a != _b) + + +class TestCompareOpError(unittest.TestCase): + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + # The input x and y of compare_op must be Variable. 
+ x = fluid.layers.data(name="x", shape=[1], dtype="float32") + y = fluid.create_lod_tensor( + numpy.array([[-1]]), [[1]], fluid.CustomPlace("intel_gpu", 0) + ) + self.assertRaises(TypeError, fluid.layers.greater_equal, x, y) + + +class API_TestElementwise_Equal(unittest.TestCase): + def test_api(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program(), fluid.Program()): + label = fluid.layers.assign(np.array([3, 3], dtype="int32")) + limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) + out = paddle.equal(x=label, y=limit) + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + (res,) = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([True, False])).all(), True) + + with fluid.program_guard(fluid.Program(), fluid.Program()): + label = fluid.layers.assign(np.array([3, 3], dtype="int32")) + limit = fluid.layers.assign(np.array([3, 3], dtype="int32")) + out = paddle.equal(x=label, y=limit) + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + (res,) = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([True, True])).all(), True) + + +class TestCompareOpPlace(unittest.TestCase): + def test_place_1(self): + paddle.enable_static() + place = paddle.CustomPlace("intel_gpu", 0) + if core.is_compiled_with_cuda(): + place = paddle.CustomPlace("intel_gpu", 0) + label = fluid.layers.assign(np.array([3, 3], dtype="int32")) + limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) + out = fluid.layers.less_than(label, limit, force_cpu=True) + exe = fluid.Executor(place) + (res,) = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([False, False])).all(), True) + + def test_place_2(self): + place = paddle.CustomPlace("intel_gpu", 0) + data_place = place + if core.is_compiled_with_cuda(): + place = paddle.CustomPlace("intel_gpu", 0) + data_place = paddle.CustomPlace("intel_gpu", 0) + paddle.disable_static(place) + data = np.array([9], dtype="int64") + data_tensor = paddle.to_tensor(data, place=data_place) + result = data_tensor == 0 + self.assertEqual((result.numpy() == np.array([False])).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_elementwise_mul_op.py b/backends/intel_gpu/tests/unittests/test_elementwise_mul_op.py new file mode 100644 index 000000000..80f3caab0 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_elementwise_mul_op.py @@ -0,0 +1,240 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest + +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +from op_test import OpTest, skip_check_grad_ci + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +class ElementwiseMulOp(OpTest): + def init_kernel_type(self): + self.use_mkldnn = False + + def setUp(self): + self.op_type = "elementwise_mul" + self.dtype = np.float32 + self.axis = -1 + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + + self.inputs = { + "X": OpTest.np_dtype_to_fluid_dtype(self.x), + "Y": OpTest.np_dtype_to_fluid_dtype(self.y), + } + self.outputs = {"Out": self.out} + self.attrs = {"axis": self.axis, "use_mkldnn": self.use_mkldnn} + + def test_check_output(self): + self.check_output(check_dygraph=(self.use_mkldnn is False)) + + def test_check_grad_normal(self): + self.check_grad(["X", "Y"], "Out", check_dygraph=(self.use_mkldnn is False)) + + def test_check_grad_ignore_x(self): + self.check_grad( + ["Y"], "Out", no_grad_set=set("X"), check_dygraph=(self.use_mkldnn is False) + ) + + def test_check_grad_ignore_y(self): + self.check_grad( + ["X"], "Out", no_grad_set=set("Y"), check_dygraph=(self.use_mkldnn is False) + ) + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + def init_dtype(self): + pass + + def init_axis(self): + pass + + +@skip_check_grad_ci(reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseMulOp_scalar(ElementwiseMulOp): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + "X": np.random.rand(10, 3, 4).astype(np.float32), + "Y": np.random.rand(1).astype(np.float32), + } + self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"]} + self.init_kernel_type() + + +class TestElementwiseMulOp_Vector(ElementwiseMulOp): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + "X": np.random.random((100,)).astype("float32"), + "Y": np.random.random((100,)).astype("float32"), + } + self.outputs = {"Out": np.multiply(self.inputs["X"], self.inputs["Y"])} + self.init_kernel_type() + + +class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x * self.y.reshape(100, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + "X": np.random.rand(2, 100, 3).astype(np.float32), + "Y": np.random.rand(100).astype(np.float32), + } + + self.attrs = {"axis": 1} + self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"].reshape(1, 100, 1)} + self.init_kernel_type() + + +class TestElementwiseMulOp_broadcast_2(ElementwiseMulOp): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + "X": np.random.rand(2, 3, 100).astype(np.float32), + "Y": np.random.rand(100).astype(np.float32), + } + + self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"].reshape(1, 1, 100)} + self.init_kernel_type() + + +class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + "X": np.random.rand(2, 10, 12, 3).astype(np.float32), + "Y": 
np.random.rand(10, 12).astype(np.float32),
+        }
+
+        self.attrs = {"axis": 1}
+        self.outputs = {
+            "Out": self.inputs["X"] * self.inputs["Y"].reshape(1, 10, 12, 1)
+        }
+        self.init_kernel_type()
+
+
+class TestElementwiseMulOp_broadcast_4(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            "X": np.random.rand(10, 2, 11).astype(np.float32),
+            "Y": np.random.rand(10, 1, 11).astype(np.float32),
+        }
+        self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"]}
+        self.init_kernel_type()
+
+
+class TestElementwiseMulOp_broadcast_5(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            "X": np.random.rand(10, 4, 2, 3).astype(np.float32),
+            "Y": np.random.rand(10, 4, 1, 3).astype(np.float32),
+        }
+        self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"]}
+        self.init_kernel_type()
+
+
+class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            "X": np.random.rand(2, 3, 100).astype(np.float32),
+            "Y": np.random.rand(1, 1, 100).astype(np.float32),
+        }
+        self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"]}
+        self.init_kernel_type()
+
+
+class TestElementwiseMulOp_commonuse_2(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            "X": np.random.rand(30, 3, 1, 5).astype(np.float32),
+            "Y": np.random.rand(30, 1, 4, 1).astype(np.float32),
+        }
+        self.outputs = {"Out": self.inputs["X"] * self.inputs["Y"]}
+        self.init_kernel_type()
+
+
+class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            "X": np.random.rand(10, 10).astype(np.float32),
+            "Y": np.random.rand(2, 2, 10, 10).astype(np.float32),
+        }
+
+        self.attrs = {"axis": 2}
+
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(1, 1, 10, 10) * self.inputs["Y"]
+        }
+        self.init_kernel_type()
+
+
+class TestElementwiseMulOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # the input of elementwise_mul must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([-1, 3, 5, 5]),
+                [[1, 1, 1, 1]],
+                fluid.CustomPlace("intel_gpu", 0),
+            )
+            y1 = fluid.create_lod_tensor(
+                np.array([-1, 3, 5, 5]),
+                [[1, 1, 1, 1]],
+                fluid.CustomPlace("intel_gpu", 0),
+            )
+            self.assertRaises(TypeError, fluid.layers.elementwise_mul, x1, y1)
+
+            # the input dtype of elementwise_mul must be float16 or float32 or float64 or int32 or int64
+            # float16 can only be set on GPU place
+            x2 = fluid.layers.data(name="x2", shape=[3, 4, 5, 6], dtype="uint8")
+            y2 = fluid.layers.data(name="y2", shape=[3, 4, 5, 6], dtype="uint8")
+            self.assertRaises(TypeError, fluid.layers.elementwise_mul, x2, y2)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/backends/intel_gpu/tests/unittests/test_fill_constant_op.py b/backends/intel_gpu/tests/unittests/test_fill_constant_op.py
new file mode 100644
index 000000000..5ccde0fcb
--- /dev/null
+++ b/backends/intel_gpu/tests/unittests/test_fill_constant_op.py
@@ -0,0 +1,241 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import numpy as np + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +paddle.enable_static() +OpTest._get_places = get_places + + +# Situation 1: Attr(shape) is a list(without tensor) +class TestFillConstantOp1(OpTest): + def setUp(self): + """Test fill_constant op with specified value""" + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {"shape": [123, 92], "value": 3.8} + self.outputs = {"Out": np.full((123, 92), 3.8)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp2(OpTest): + def setUp(self): + """Test fill_constant op with default value""" + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {"shape": [123, 92]} + self.outputs = {"Out": np.full((123, 92), 0.0)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp3(OpTest): + def setUp(self): + """Test fill_constant op with specified int64 value""" + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {"shape": [123, 92], "value": 10000000000} + self.outputs = {"Out": np.full((123, 92), 10000000000)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp4(OpTest): + def setUp(self): + """Test fill_constant op with specified int value""" + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = {"shape": [123, 92], "value": 3} + self.outputs = {"Out": np.full((123, 92), 3)} + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOpWithSelectedRows(unittest.TestCase): + def check_with_place(self, place): + scope = core.Scope() + # create Out Variable + out = scope.var("Out").get_selected_rows() + + # create and run fill_constant_op operator + fill_constant_op = Operator( + "fill_constant", shape=[123, 92], value=3.8, Out="Out" + ) + fill_constant_op.run(scope, place) + + # get result from Out + result_array = np.array(out.get_tensor()) + full_array = np.full((123, 92), 3.8, "float32") + + self.assertTrue(np.array_equal(result_array, full_array)) + + def test_fill_constant_with_selected_rows(self): + places = [core.CustomPlace("intel_gpu", 0)] + + for place in places: + self.check_with_place(place) + + +# Situation 2: Attr(shape) is a list(with tensor) +class TestFillConstantOp1_ShapeTensorList(OpTest): + def setUp(self): + """Test fill_constant op with specified value""" + self.op_type = "fill_constant" + self.init_data() + shape_tensor_list = [] + for index, ele in enumerate(self.shape): + shape_tensor_list.append( + ("x" + str(index), np.ones((1)).astype("int32") * ele) + ) + + self.inputs = {"ShapeTensorList": shape_tensor_list} + self.attrs = {"shape": self.infer_shape, "value": self.value} + self.outputs = {"Out": np.full(self.shape, self.value)} + + def init_data(self): + self.shape = [123, 92] + self.infer_shape = [-1, 92] + self.value = 3.8 + + def test_check_output(self): + self.check_output() + + +class 
TestFillConstantOp2_ShapeTensorList(OpTest): + def setUp(self): + """Test fill_constant op with default value""" + self.op_type = "fill_constant" + self.init_data() + shape_tensor_list = [] + for index, ele in enumerate(self.shape): + shape_tensor_list.append( + ("x" + str(index), np.ones((1)).astype("int32") * ele) + ) + + self.inputs = {"ShapeTensorList": shape_tensor_list} + self.attrs = {"shape": self.infer_shape} + self.outputs = {"Out": np.full(self.shape, 0.0)} + + def init_data(self): + self.shape = [123, 92] + self.infer_shape = [-1, -1] + + def test_check_output(self): + self.check_output() + + +class TestFillConstantOp3_ShapeTensorList(TestFillConstantOp1_ShapeTensorList): + def init_data(self): + self.shape = [123, 92] + self.infer_shape = [123, -1] + self.value = 10000000000 + + +class TestFillConstantOp4_ShapeTensorList(TestFillConstantOp1_ShapeTensorList): + def init_data(self): + self.shape = [123, 92] + self.infer_shape = [123, -1] + self.value = 3 + + +# Situation 3: shape is a tensor +class TestFillConstantOp1_ShapeTensor(OpTest): + def setUp(self): + """Test fill_constant op with specified value""" + self.op_type = "fill_constant" + self.init_data() + + self.inputs = {"ShapeTensor": np.array(self.shape).astype("int32")} + self.attrs = {"value": self.value} + self.outputs = {"Out": np.full(self.shape, self.value)} + + def init_data(self): + self.shape = [123, 92] + self.value = 3.8 + + def test_check_output(self): + self.check_output() + + +# Situation 4: value is a tensor +class TestFillConstantOp1_ValueTensor(OpTest): + def setUp(self): + """Test fill_constant op with specified value""" + self.op_type = "fill_constant" + self.init_data() + + self.inputs = { + "ShapeTensor": np.array(self.shape).astype("int32"), + "ValueTensor": np.array([self.value]).astype("float32"), + } + self.attrs = {"value": self.value + 1.0} + self.outputs = {"Out": np.full(self.shape, self.value)} + + def init_data(self): + self.shape = [123, 92] + self.value = 3.8 + self.dtype = np.float32 + + def test_check_output(self): + self.check_output() + + +# Situation 5: value is a tensor +class TestFillConstantOp2_ValueTensor(OpTest): + def setUp(self): + """Test fill_constant op with specified value""" + self.op_type = "fill_constant" + self.init_data() + + self.inputs = { + "ShapeTensor": np.array(self.shape).astype("int32"), + "ValueTensor": np.array([self.value]).astype("int32"), + } + self.attrs = {"value": self.value, "dtype": 2} + self.outputs = {"Out": np.full(self.shape, self.value)} + + def init_data(self): + self.shape = [123, 92] + self.value = 3 + self.dtype = np.int32 + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_mean_op.py b/backends/intel_gpu/tests/unittests/test_mean_op.py new file mode 100644 index 000000000..8dea07689 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_mean_op.py @@ -0,0 +1,96 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +np.random.seed(10) + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +def mean_wrapper(x, axis=None, keepdim=False, reduce_all=False): + if reduce_all: + return paddle.mean(x, range(len(x.shape)), keepdim) + return paddle.mean(x, axis, keepdim) + + +def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False): + if reduce_all: + return paddle.mean(x, range(len(x.shape)), keepdim) + return paddle.mean(x, axis, keepdim) + + +class TestMeanOp(OpTest): + def setUp(self): + self.op_type = "mean" + self.python_api = fluid.layers.mean + self.dtype = np.float64 + self.init_dtype_type() + self.inputs = {"X": np.random.random((10, 10)).astype(self.dtype)} + self.outputs = {"Out": np.mean(self.inputs["X"])} + + def init_dtype_type(self): + pass + + def test_check_output(self): + self.check_output(check_eager=False) + + def test_checkout_grad(self): + self.check_grad(["X"], "Out", check_eager=False) + + +class TestMeanOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of mean_op must be Variable. + input1 = 12 + self.assertRaises(TypeError, fluid.layers.mean, input1) + # The input dtype of mean_op must be float16, float32, float64. + input2 = fluid.layers.data(name="input2", shape=[12, 10], dtype="int32") + self.assertRaises(TypeError, fluid.layers.mean, input2) + input3 = fluid.layers.data(name="input3", shape=[4], dtype="float16") + fluid.layers.softmax(input3) + + +def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False): + if isinstance(axis, list): + axis = tuple(axis) + if reduce_all: + axis = None + return np.mean(x, axis=axis, keepdims=keepdim) + + +def ref_reduce_mean_grad(x, axis, dtype, reduce_all=False): + if reduce_all: + axis = list(range(x.ndim)) + + shape = [x.shape[i] for i in axis] + return (1.0 / np.prod(shape) * np.ones(shape)).astype(dtype) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_memcpy_op.py b/backends/intel_gpu/tests/unittests/test_memcpy_op.py new file mode 100644 index 000000000..06ad6bf9d --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_memcpy_op.py @@ -0,0 +1,117 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
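+
+# NOTE: the memcpy cases below copy tensors between the CPU place and the
+# intel_gpu CustomPlace through the "memcpy" op, using dst_place_type=0 for
+# the device-to-CPU direction and dst_place_type=6 for the CPU-to-device
+# direction, then check that both buffers hold the same values.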
+ +from __future__ import print_function + +import numpy as np +import unittest + +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +SEED = 2021 + + +class TestMemcpy_API(unittest.TestCase): + def init_config(self): + self.dtype = "float32" + self.shape = [10, 10] + + def get_prog(self): + self.init_config() + self.__class__.use_custom_device = True + paddle.enable_static() + main_program = Program() + with program_guard(main_program): + cpu_var_name = "tensor@Cpu" + intel_gpu_var_name = "tensor@intel_gpu" + cpu_var = main_program.global_block().create_var( + name=cpu_var_name, + shape=self.shape, + dtype=self.dtype, + persistable=False, + stop_gradient=True, + ) + intel_gpu_var = main_program.global_block().create_var( + name=intel_gpu_var_name, + shape=self.shape, + dtype=self.dtype, + persistable=False, + stop_gradient=True, + ) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": intel_gpu_var_name}, + attrs={ + "shape": intel_gpu_var.shape, + "dtype": intel_gpu_var.dtype, + "value": 1.0, + }, + ) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": cpu_var_name}, + attrs={ + "shape": cpu_var.shape, + "dtype": cpu_var.dtype, + "value": 0.0, + "place_type": 0, + }, + ) + return main_program, intel_gpu_var, cpu_var + + def test_intel_gpu_copy_to_cpu(self): + self.__class__.use_custom_device = True + main_program, intel_gpu_var, cpu_var = self.get_prog() + main_program.global_block().append_op( + type="memcpy", + inputs={"X": intel_gpu_var}, + outputs={"Out": cpu_var}, + attrs={"dst_place_type": 0}, + ) + place = paddle.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + intel_gpu_, cpu_ = exe.run( + main_program, feed={}, fetch_list=[intel_gpu_var.name, cpu_var.name] + ) + np.testing.assert_allclose(intel_gpu_, cpu_) + np.testing.assert_allclose(cpu_, np.ones(self.shape, dtype=self.dtype)) + + def test_cpu_copy_intel_gpu(self): + self.__class__.use_custom_device = True + main_program, intel_gpu_var, cpu_var = self.get_prog() + main_program.global_block().append_op( + type="memcpy", + inputs={"X": cpu_var}, + outputs={"Out": intel_gpu_var}, + attrs={"dst_place_type": 6}, + ) + place = paddle.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + intel_gpu_, cpu_ = exe.run( + main_program, feed={}, fetch_list=[intel_gpu_var.name, cpu_var.name] + ) + np.testing.assert_allclose(intel_gpu_, cpu_) + np.testing.assert_allclose(intel_gpu_, np.zeros(self.shape, dtype=self.dtype)) + + +class TestMemcpy_3D(TestMemcpy_API): + def init_config(self): + self.dtype = "float32" + self.shape = [15, 10, 5] + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_reduce_op.py b/backends/intel_gpu/tests/unittests/test_reduce_op.py new file mode 100644 index 000000000..66ea46ba4 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_reduce_op.py @@ -0,0 +1,439 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +# only support reduce sum with dims < 7 +class TestSumOp3D(OpTest): + def setUp(self): + self.python_api = paddle.sum + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")} + self.outputs = {"Out": self.inputs["X"].sum(axis=0)} + self.attrs = {"dim": [0]} + + def test_check_output(self): + self.check_output(check_eager=False) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_eager=False) + + +class TestSumOp4D(OpTest): + def setUp(self): + self.python_api = paddle.sum + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((1, 5, 6, 10)).astype("float32")} + self.attrs = {"dim": [0]} + self.outputs = {"Out": self.inputs["X"].sum(axis=0)} + + def test_check_output(self): + self.check_output(check_eager=False) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_eager=False) + + +class TestSumOp5D(OpTest): + def setUp(self): + self.python_api = paddle.sum + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((1, 2, 5, 6, 10)).astype("float32")} + self.attrs = {"dim": [0]} + self.outputs = {"Out": self.inputs["X"].sum(axis=0)} + + def test_check_output(self): + self.check_output(check_eager=False) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_eager=False) + + +class TestSumOp6D(OpTest): + def setUp(self): + self.python_api = paddle.sum + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((1, 1, 2, 5, 6, 10)).astype("float32")} + self.attrs = {"dim": [0]} + self.outputs = {"Out": self.inputs["X"].sum(axis=0)} + + def test_check_output(self): + self.check_output(check_eager=False) + + def test_check_grad(self): + self.check_grad(["X"], "Out", check_eager=False) + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework." +) +class TestMaxOp(OpTest): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.python_api = paddle.max + self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {"dim": [-1]} + self.outputs = {"Out": self.inputs["X"].max(axis=tuple(self.attrs["dim"]))} + + def test_check_output(self): + self.check_output(check_eager=False) + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework." 
+)
+class TestMinOp(OpTest):
+    """Remove Min with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.python_api = paddle.min
+        self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {"dim": [2]}
+        self.outputs = {"Out": self.inputs["X"].min(axis=tuple(self.attrs["dim"]))}
+
+    def test_check_output(self):
+        self.check_output(check_eager=False)
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework."
+)
+class TestMin6DOp(OpTest):
+    """Remove Min with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.python_api = paddle.min
+        self.inputs = {"X": np.random.random((2, 4, 3, 5, 6, 10)).astype("float32")}
+        self.attrs = {"dim": [2, 4]}
+        self.outputs = {"Out": self.inputs["X"].min(axis=tuple(self.attrs["dim"]))}
+
+    def test_check_output(self):
+        self.check_output(check_eager=False)
+
+
+class Test1DReduce(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {"X": np.random.random(120).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=0)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class Test2DReduce0(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.attrs = {"dim": [0]}
+        self.inputs = {"X": np.random.random((20, 10)).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=0)}
+
+
+class Test2DReduce1(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.attrs = {"dim": [1]}
+        self.inputs = {"X": np.random.random((20, 10)).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]))}
+
+
+class Test3DReduce0(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.attrs = {"dim": [1]}
+        self.inputs = {"X": np.random.random((5, 6, 7)).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]))}
+
+
+class Test3DReduce1(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.attrs = {"dim": [2]}
+        self.inputs = {"X": np.random.random((5, 6, 7)).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]))}
+
+
+class Test3DReduce2(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.attrs = {"dim": [-2]}
+        self.inputs = {"X": np.random.random((5, 6, 7)).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]))}
+
+
+class Test3DReduce3(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.attrs = {"dim": [1, 2]}
+        self.inputs = {"X": np.random.random((5, 6, 7)).astype("float32")}
+        self.outputs = {"Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]))}
+
+
+class TestKeepDimReduce(Test1DReduce):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {"dim": [1], "keep_dim": True}
+        self.outputs = {
+            "Out": self.inputs["X"].sum(
+                axis=tuple(self.attrs["dim"]), keepdims=self.attrs["keep_dim"]
+            )
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_max is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework."
+) +class TestReduceMaxOpMultiAxises(OpTest): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.python_api = paddle.max + self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {"dim": [-2, -1]} + self.outputs = {"Out": self.inputs["X"].max(axis=tuple(self.attrs["dim"]))} + + def test_check_output(self): + self.check_output(check_eager=False) + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework." +) +class TestReduceMinOpMultiAxises(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.python_api = paddle.min + self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {"dim": [1, 2]} + self.outputs = {"Out": self.inputs["X"].min(axis=tuple(self.attrs["dim"]))} + + def test_check_output(self): + self.check_output(check_eager=False) + + +class TestKeepDimReduceSumMultiAxises(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {"dim": [-2, -1], "keep_dim": True} + self.outputs = { + "Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]), keepdims=True) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReduceSumWithDimOne(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((100, 1, 1)).astype("float32")} + self.attrs = {"dim": [1, 2], "keep_dim": True} + self.outputs = { + "Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]), keepdims=True) + } + + def test_check_output(self): + self.check_output() + + # def test_check_grad(self): + # self.check_grad(['X'], 'Out') + + +class TestReduceSumWithNumelOne(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((100, 1)).astype("float32")} + self.attrs = {"dim": [1], "keep_dim": False} + self.outputs = { + "Out": self.inputs["X"].sum(axis=tuple(self.attrs["dim"]), keepdims=False) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReduceAll(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((100, 1, 1)).astype("float32")} + self.attrs = {"reduce_all": True, "keep_dim": False} + self.outputs = {"Out": self.inputs["X"].sum()} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class Test1DReduceWithAxes1(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random(100).astype("float32")} + self.attrs = {"dim": [0], "keep_dim": False} + self.outputs = {"Out": self.inputs["X"].sum(axis=0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReduceWithDtype(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((6, 2, 10)).astype("float64")} + self.outputs = {"Out": self.inputs["X"].sum().astype("float64")} + self.attrs = {"reduce_all": True} + self.attrs.update( + { + "in_dtype": int(convert_np_dtype_to_dtype_(np.float32)), + "out_dtype": int(convert_np_dtype_to_dtype_(np.float64)), + } + ) + + 
def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReduceWithDtype1(TestReduceWithDtype): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((6, 2, 10)).astype("float64")} + self.outputs = {"Out": self.inputs["X"].sum(axis=1)} + self.attrs = {"dim": [1]} + self.attrs.update( + { + "in_dtype": int(convert_np_dtype_to_dtype_(np.float32)), + "out_dtype": int(convert_np_dtype_to_dtype_(np.float64)), + } + ) + + +class TestReduceWithDtype2(TestReduceWithDtype): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {"X": np.random.random((6, 2, 10)).astype("float64")} + self.outputs = {"Out": self.inputs["X"].sum(axis=1, keepdims=True)} + self.attrs = {"dim": [1], "keep_dim": True} + self.attrs.update( + { + "in_dtype": int(convert_np_dtype_to_dtype_(np.float32)), + "out_dtype": int(convert_np_dtype_to_dtype_(np.float64)), + } + ) + + +class TestReduceSumOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of reduce_sum_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CustomPlace("intel_gpu", 0) + ) + self.assertRaises(TypeError, fluid.layers.reduce_sum, x1) + # The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64. + x2 = fluid.layers.data(name="x2", shape=[4], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.reduce_sum, x2) + + +class API_TestSumOp(unittest.TestCase): + def run_static(self, shape, x_dtype, attr_axis, attr_dtype=None, np_axis=None): + if np_axis is None: + np_axis = attr_axis + + places = [fluid.CustomPlace("intel_gpu", 0)] + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.data("data", shape=shape, dtype=x_dtype) + result_sum = paddle.sum(x=data, axis=attr_axis, dtype=attr_dtype) + + exe = fluid.Executor(place) + input_data = np.random.rand(*shape).astype(x_dtype) + (res,) = exe.run(feed={"data": input_data}, fetch_list=[result_sum]) + + self.assertTrue( + np.allclose(res, np.sum(input_data.astype(attr_dtype), axis=np_axis)) + ) + + def test_static(self): + shape = [10, 10] + axis = 1 + + self.run_static(shape, "float32", axis) + + shape = [5, 5, 5] + self.run_static(shape, "float32", (0, 1)) + self.run_static(shape, "float32", (), np_axis=(0, 1, 2)) + + def test_dygraph(self): + np_x = np.random.random([2, 3, 4]).astype("float32") + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + x = fluid.dygraph.to_variable(np_x) + out0 = paddle.sum(x).numpy() + out1 = paddle.sum(x, axis=0).numpy() + out2 = paddle.sum(x, axis=(0, 1)).numpy() + out3 = paddle.sum(x, axis=(0, 1, 2)).numpy() + self.assertTrue(np.allclose(out0, np.sum(np_x, axis=(0, 1, 2)), 1e-5, 1e-5)) + self.assertTrue(np.allclose(out1, np.sum(np_x, axis=0), 1e-5, 1e-5)) + self.assertTrue(np.allclose(out2, np.sum(np_x, axis=(0, 1)), 1e-5, 1e-5)) + self.assertTrue(np.allclose(out3, np.sum(np_x, axis=(0, 1, 2)), 1e-5, 1e-5)) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_reshape_op.py b/backends/intel_gpu/tests/unittests/test_reshape_op.py new file mode 100755 index 000000000..8308d8452 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_reshape_op.py @@ -0,0 +1,352 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.static import Program, program_guard + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +paddle.enable_static() +OpTest._get_places = get_places + + +# situation 1: have shape( list, no tensor), no actual shape(Tensor) +class TestReshapeOp(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.attrs = {"shape": self.new_shape} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + "XShape": np.random.random(self.ori_shape).astype("float32"), + } + + def init_data(self): + self.ori_shape = (2, 60) + self.new_shape = (12, 10) + self.infered_shape = (12, 10) + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReshapeOpDimInfer1(TestReshapeOp): + def init_data(self): + self.ori_shape = (5, 25) + self.new_shape = (5, -1, 5) + self.infered_shape = (5, -1, 5) + + +class TestReshapeOpDimInfer2(TestReshapeOp): + def init_data(self): + self.ori_shape = (10, 2, 6) + self.new_shape = (10, 0, 3, -1) + self.infered_shape = (10, 2, 3, -1) + + +# situation 2: have shape(list, no tensor), have actual shape(Tensor) +class TestReshapeOpWithInputShape(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "Shape": np.array(self.actual_shape, dtype="int32"), + } + self.attrs = {"shape": self.new_shape} + self.outputs = { + "Out": self.inputs["X"].reshape(self.actual_shape), + "XShape": np.random.random(self.ori_shape).astype("float32"), + } + + def init_data(self): + self.ori_shape = (6, 20) + self.new_shape = (0, -1, 20) + self.actual_shape = (2, 3, 20) + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Situation 3: have shape(list, have tensor), no actual shape(Tensor) +class TestReshapeOp_attr_ShapeTensor(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones((1)).astype("int32") * ele)) + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "ShapeTensor": shape_tensor, + } + self.attrs = {"shape": self.shape} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + "XShape": np.random.random(self.ori_shape).astype("float32"), + } + + def init_data(self): + self.ori_shape = (4, 25) + self.new_shape = (10, 10) + self.infered_shape = (10, 10) + self.shape = (-1, -1) + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + 
self.check_grad(["X"], "Out") + + +class TestReshapeOpDimInfer1_attr_ShapeTensor(TestReshapeOp_attr_ShapeTensor): + def init_data(self): + self.ori_shape = (5, 20) + self.new_shape = (5, -1, 20) + self.infered_shape = (5, -1, 20) + self.shape = (5, -1, -1) + + +class TestReshapeOpDimInfer2_attr_ShapeTensor(TestReshapeOp_attr_ShapeTensor): + def init_data(self): + self.ori_shape = (10, 2, 6) + self.new_shape = (10, 0, 3, -1) + self.infered_shape = (10, 2, 3, -1) + self.shape = (10, 0, 3, -1) + + +# Situation 4: have shape(Tensor), no actual shape(Tensor) +class TestReshapeOp_attr_OnlyShape(OpTest): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "Shape": np.array(self.new_shape, dtype="int32"), + } + self.attrs = {} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + "XShape": np.random.random(self.ori_shape).astype("float32"), + } + + def init_data(self): + self.ori_shape = (4, 25) + self.new_shape = (10, 10) + self.infered_shape = (10, 10) + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReshapeOpDimInfer1_attr_OnlyShape(TestReshapeOp_attr_OnlyShape): + def init_data(self): + self.ori_shape = (5, 20) + self.new_shape = (5, -1, 10) + self.infered_shape = (5, -1, 10) + self.shape = (5, -1, -1) + + +class TestReshapeOpDimInfer2_attr_OnlyShape(TestReshapeOp_attr_OnlyShape): + def init_data(self): + self.ori_shape = (10, 2, 6) + self.new_shape = (10, 0, 3, -1) + self.infered_shape = (10, 2, 3, -1) + self.shape = (10, 0, 3, -1) + + +class TestReshapeOpBool(TestReshapeOp): + def setUp(self): + self.init_data() + self.op_type = "reshape2" + self.inputs = {"X": np.random.choice([True, False], size=self.ori_shape)} + self.attrs = {"shape": self.new_shape} + self.outputs = { + "Out": self.inputs["X"].reshape(self.infered_shape), + "XShape": np.random.random(self.ori_shape).astype("float32"), + } + + def test_check_grad(self): + pass + + +# Test python API +class TestReshapeAPI(unittest.TestCase): + def _set_paddle_api(self): + self.fill_constant = paddle.fluid.layers.fill_constant + self.data = paddle.static.data + self.to_tensor = paddle.to_tensor + self._executed_api() + + def _executed_api(self): + self.reshape = paddle.reshape + + def _set_fluid_api(self): + self.fill_constant = fluid.layers.fill_constant + self.data = paddle.static.data + self.reshape = fluid.layers.reshape + + def _test_api(self): + paddle.enable_static() + input = np.random.random([2, 25]).astype("float32") + shape = [2, 5, 5] + main_prog = Program() + with program_guard(main_prog, Program()): + positive_five = self.fill_constant([1], "int32", 5) + x = self.data(name="x", shape=[2, 25], dtype="float32") + + actual_shape = self.data(name="shape", shape=[3], dtype="int32") + + # situation 1: have shape( list, no tensor), no actual shape(Tensor) + out_1 = self.reshape(x, shape) + + # situation 2: have shape(list, no tensor), have actual shape(Tensor) + out_2 = fluid.layers.reshape(x, shape=shape, actual_shape=actual_shape) + + # Situation 3: have shape(list, have tensor), no actual shape(Tensor) + out_3 = self.reshape(x, shape=[positive_five, 10]) + + # Situation 4: have shape(Tensor), no actual shape(Tensor) + out_4 = self.reshape(x, shape=actual_shape) + + exe = paddle.static.Executor(place=paddle.CustomPlace("intel_gpu", 0)) + res_1, res_2, res_3, res_4 = exe.run( + main_prog, + feed={"x": input, 
"shape": np.array([2, 5, 5]).astype("int32")}, + fetch_list=[out_1, out_2, out_3, out_4], + ) + + assert np.array_equal(res_1, input.reshape(shape)) + assert np.array_equal(res_2, input.reshape(shape)) + assert np.array_equal(res_3, input.reshape([5, 10])) + assert np.array_equal(res_4, input.reshape(shape)) + + def test_paddle_api(self): + self._set_paddle_api() + self._test_api() + + def test_fluid_api(self): + self._set_fluid_api() + self._test_api() + + def test_imperative(self): + self._set_paddle_api() + input = np.random.random([2, 25]).astype("float32") + shape = [2, 5, 5] + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + x = self.to_tensor(input) + positive_five = self.fill_constant([1], "int32", 5) + + out_1 = self.reshape(x, shape) + + out_2 = self.reshape(x, shape=[positive_five, 10]) + + shape_tensor = self.to_tensor(np.array([2, 5, 5]).astype("int32")) + out_3 = self.reshape(x, shape=shape_tensor) + + assert np.array_equal(out_1.numpy(), input.reshape(shape)) + assert np.array_equal(out_2.numpy(), input.reshape([5, 10])) + assert np.array_equal(out_3.numpy(), input.reshape(shape)) + + +class TestStaticReshape_(TestReshapeAPI): + def _executed_api(self): + self.reshape = paddle.reshape_ + + def test_imperative(self): + self._set_paddle_api() + input = np.random.random([2, 25]).astype("float32") + shape = [2, 5, 5] + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + x = self.to_tensor(input) + positive_five = self.fill_constant([1], "int32", 5) + + out_1 = self.reshape(x, shape) + + out_2 = self.reshape(x, shape=[positive_five, 10]) + + shape_tensor = self.to_tensor(np.array([2, 5, 5]).astype("int32")) + out_3 = self.reshape(x, shape=shape_tensor) + + assert np.array_equal(out_1.numpy(), input.reshape(shape)) + assert np.array_equal(out_2.numpy(), input.reshape(shape)) + assert np.array_equal(out_3.numpy(), input.reshape(shape)) + + +class TestDygraphReshapeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.reshape = paddle.reshape + + def test_out(self): + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + input_1 = np.random.random([5, 1, 10]).astype("int32") + input = paddle.to_tensor(input_1) + output = self.reshape(x=input, shape=[5, 10]) + out_np = output.numpy() + expected_out = np.reshape(input_1, newshape=[5, 10]) + self.assertTrue(np.allclose(expected_out, out_np)) + + def test_out_uint8(self): + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + input_1 = np.random.random([5, 1, 10]).astype("uint8") + input = paddle.to_tensor(input_1) + output = self.reshape(x=input, shape=[5, 10]) + out_np = output.numpy() + expected_out = np.reshape(input_1, newshape=[5, 10]) + self.assertTrue(np.allclose(expected_out, out_np)) + + def test_out_float32(self): + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + input_1 = np.random.random([5, 1, 10]).astype("float32") + input = paddle.to_tensor(input_1) + output = self.reshape(x=input, shape=[5, 10]) + out_np = output.numpy() + expected_out = np.reshape(input_1, newshape=[5, 10]) + self.assertTrue(np.allclose(expected_out, out_np)) + + +class TestDygraphReshapeInplaceAPI(TestDygraphReshapeAPI): + def executed_api(self): + self.reshape = paddle.reshape_ + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_slice_op.py b/backends/intel_gpu/tests/unittests/test_slice_op.py new file mode 100644 index 000000000..9c03c7501 --- /dev/null +++ 
b/backends/intel_gpu/tests/unittests/test_slice_op.py @@ -0,0 +1,655 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid +import paddle + +paddle.enable_static() + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +# Situation 1: starts(list, no tensor), ends(list, no tensor) +# 1.1 without attr(decrease) +class TestSliceOp(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + self.inputs = {"Input": self.input} + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "starts": self.starts, + "ends": self.ends, + "infer_flags": self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.infer_flags = [1, 1, 1] + self.out = self.input[1:3, 0:3, 2:4, :] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +class TestCase1(TestSliceOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [-3, 0, 2] + self.ends = [3, 100, -1] + self.axes = [0, 1, 2] + self.infer_flags = [1, 1, 1] + self.out = self.input[-3:3, 0:100, 2:-1, :] + + +class TestCase2(TestSliceOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [-3, 0, 2] + self.ends = [3, 100, -1] + self.axes = [0, 1, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[-3:3, 0:100, :, 2:-1] + + +# 1.2 with attr(decrease) +class TestSliceOp_decs_dim(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + self.inputs = {"Input": self.input} + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "starts": self.starts, + "ends": self.ends, + "infer_flags": self.infer_flags, + "decrease_axis": self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [1, 1, 1] + self.out = self.input[1, 0:3, 2:4, :] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [1, 1, 1] + self.out = self.input[1, 0, 2:4, :] + + +class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [-1, 0, 2] 
+ self.ends = [1000000, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [1, 1, 1] + self.out = self.input[-1, 0, 2:4, :] + + +class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim): + def config(self): + self.input = np.random.random([3, 4, 5, 7]).astype("float64") + self.starts = [0, 1, 2, 3] + self.ends = [1, 2, 3, 4] + self.axes = [0, 1, 2, 3] + self.decrease_axis = [0, 1, 2, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[0, 1, 2, 3:4] + + +class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [-1] + self.ends = [1000000] + self.axes = [3] + self.decrease_axis = [3] + self.infer_flags = [1, 1, 1] + self.out = self.input[:, :, :, -1] + + +class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [0, 1, 2, 3] + self.ends = [1, 2, 3, 4] + self.axes = [0, 1, 2, 3] + self.decrease_axis = [0, 1, 2, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[0, 1, 2, 3:4] + + +# Situation 2: starts(list, have tensor), ends(list, no tensor) +# without attr(decrease) +class TestSliceOp_starts_ListTensor(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + + starts_tensor = [] + for index, ele in enumerate(self.starts): + starts_tensor.append(("x" + str(index), np.ones((1)).astype("int64") * ele)) + + self.inputs = {"Input": self.input, "StartsTensorList": starts_tensor} + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "starts": self.starts_infer, + "ends": self.ends, + "infer_flags": self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.infer_flags = [-1, 1, -1] + self.out = self.input[1:3, 0:3, 2:4, :] + + self.starts_infer = [-1, 0, -1] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +# Situation 2: starts(list, have tensor), ends(list, no tensor) +# with attr(decrease) +class TestSliceOp_decs_dim_starts_ListTensor(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + + starts_tensor = [] + for index, ele in enumerate(self.starts): + starts_tensor.append(("x" + str(index), np.ones((1)).astype("int32") * ele)) + + self.inputs = {"Input": self.input, "StartsTensorList": starts_tensor} + + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "starts": self.starts_infer, + "ends": self.ends, + "infer_flags": self.infer_flags, + "decrease_axis": self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [1, -1, 1] + self.out = self.input[1, 0:3, 2:4, :] + + self.starts_infer = [1, -1, 2] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +class TestSliceOp_decs_dim_5_starts_ListTensor(TestSliceOp_decs_dim_starts_ListTensor): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [-1] + self.ends = [1000000] + self.axes = [3] + self.decrease_axis = [3] + self.infer_flags = [-1] + self.out = self.input[:, :, :, -1] + + 
self.starts_infer = [-1] + + +# Situation 3: starts(tensor), ends(list, no tensor) +# with attr(decrease) +class TestSliceOp_decs_dim_starts_OneTensor(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + self.inputs = { + "Input": self.input, + "StartsTensor": np.array(self.starts, dtype="int32"), + } + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "ends": self.ends, + "infer_flags": self.infer_flags, + "decrease_axis": self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [-1, -1, -1] + self.out = self.input[1, 0:3, 2:4, :] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +# Situation 4: starts(tensor), ends(tensor) +# without attr(decrease) +class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + + self.inputs = { + "Input": self.input, + "StartsTensor": np.array(self.starts, dtype="int64"), + "EndsTensor": np.array(self.ends, dtype="int32"), + } + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "infer_flags": self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.infer_flags = [-1, -1, -1] + self.out = self.input[1:3, 0:3, 2:4, :] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +# Situation 5: starts(tensor), ends(tensor) +# with attr(decrease) +class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + self.inputs = { + "Input": self.input, + "StartsTensor": np.array(self.starts, dtype="int32"), + "EndsTensor": np.array(self.ends, dtype="int32"), + } + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "infer_flags": self.infer_flags, + "decrease_axis": self.decrease_axis, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [-1, -1, -1] + self.out = self.input[1, 0, 2:4, :] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +# Situation 6: starts(tensor), ends(list, have tensor) +# without attr(decrease) +class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + + ends_tensor = [] + for index, ele in enumerate(self.ends): + ends_tensor.append(("y" + str(index), np.ones((1)).astype("int32") * ele)) + + self.inputs = { + "Input": self.input, + "StartsTensor": np.array(self.starts, dtype="int32"), + "EndsTensorList": ends_tensor, + } + self.outputs = {"Out": self.out} + self.attrs = { + "axes": self.axes, + "ends": self.ends_infer, + "infer_flags": self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.infer_flags = [-1, -1, -1] + self.out = self.input[1:3, 0:3, 2:4, :] + + 
self.ends_infer = [-1, 3, 4] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(["Input"], "Out", max_relative_error=0.006) + + +# Test python API +class TestSliceAPI(unittest.TestCase): + def test_1(self): + input = np.random.random([3, 4, 5, 6]).astype("float64") + minus_1 = fluid.layers.fill_constant([1], "int32", -1) + minus_3 = fluid.layers.fill_constant([1], "int64", -3) + starts = fluid.layers.data(name="starts", shape=[1, 3], append_batch_size=False) + ends = fluid.layers.data(name="ends", shape=[3], append_batch_size=False) + + x = fluid.layers.data( + name="x", shape=[3, 4, 5, 6], append_batch_size=False, dtype="float64" + ) + + # value_int64 is greater than 2147483647 which is the max of int32 + value_int64 = fluid.layers.fill_constant([1], "int64", 2147483648) + + out_1 = paddle.slice( + x, axes=[0, 1, 2], starts=[-3, 0, 2], ends=[value_int64, 100, -1] + ) + out_2 = paddle.slice( + x, axes=[0, 1, 3], starts=[minus_3, 0, 2], ends=[3, 100, -1] + ) + out_3 = paddle.slice( + x, axes=[0, 1, 3], starts=[minus_3, 0, 2], ends=[3, 100, minus_1] + ) + out_4 = paddle.slice(x, axes=[0, 1, 2], starts=starts, ends=ends) + + out_5 = x[-3:3, 0:100, 2:-1] + out_6 = x[minus_3:3, 0:100, :, 2:-1] + out_7 = x[minus_1, 0:100, :, 2:minus_1] + + exe = fluid.Executor(place=fluid.CustomPlace("intel_gpu", 0)) + res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( + fluid.default_main_program(), + feed={ + "x": input, + "starts": np.array([-3, 0, 2]).astype("int32"), + "ends": np.array([3, 100, -1]).astype("int32"), + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], + ) + + assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1]) + + +class TestSliceApiWithTensor(unittest.TestCase): + def test_starts_ends_is_tensor(self): + with paddle.fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + a = paddle.rand(shape=[4, 5, 6], dtype="float32") + axes = [0, 1, 2] + starts = [-3, 0, 2] + ends = [3, 2, 4] + a_1 = paddle.slice( + a, + axes=axes, + starts=paddle.to_tensor(starts, dtype="int32"), + ends=paddle.to_tensor(ends, dtype="int32"), + ) + a_2 = paddle.slice(a, axes=axes, starts=starts, ends=ends) + + self.assertTrue(np.array_equal(a_1.numpy(), a_2.numpy())) + + def test_bool_tensor(self): + with paddle.fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + array = (np.arange(60).reshape([3, 4, 5]) % 3).astype("bool") + tt = paddle.to_tensor(array) + tt.stop_gradient = False + + starts = [0, 1, 2] + ends = [3, 5, 4] + axes = [0, 1, 2] + + y_paddle = paddle.slice(tt, axes, starts, ends) + y_np = tt[0:3, 1:5, 2:4] + + self.assertTrue(paddle.bool == y_paddle.dtype) + self.assertTrue(np.array_equal(y_paddle.numpy(), y_np)) + + +# unknown bug (wangran16) + + +class TestSliceApiEager(unittest.TestCase): + def test_slice_api(self): + with paddle.fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + a = paddle.rand(shape=[4, 5, 6], dtype="float32") + a.stop_gradient = False + axes = [0, 1, 2] + starts = [-3, 0, 2] + ends = [3, 2, 4] + a_1 = paddle.slice(a, axes=axes, starts=starts, ends=ends) + + a_2 = paddle.slice( + a, + axes=axes, + starts=paddle.to_tensor(starts), 
+ ends=paddle.to_tensor(ends), + ) + + a_1.backward() + grad_truth = paddle.zeros_like(a) + grad_truth[-3:3, 0:2, 2:4] = 1 + self.assertTrue(np.array_equal(grad_truth, a.gradient())) + + self.assertTrue(np.allclose(a_1.numpy(), a[-3:3, 0:2, 2:4])) + + +# TODO(Zhiwei35): the case's expected dispatch to phi kernels, but fluid +# kernels, I think this is paddle bug? +# UnimplementedError: There are no kernels which are registered in the memcpy_d2h operator. +# [Hint: Expected kernels_iter != all_op_kernels.end(), but received +# kernels_iter == all_op_kernels.end().] (at /home/gta/chaofanl/PaddleIntelGPUDevice/ +# Paddle/paddle/fluid/framework/operator.cc:1893) +# [operator < slice > error] +# class TestSliceApiWithLoDTensorArray(unittest.TestCase): +# def setUp(self): +# self.shape = (3, 4) +# self.data = np.random.random(size=self.shape).astype('float32') +# self.idx = 0 +# self.start = 0 +# self.end = 2 +# self.axis = 1 + +# self.place = fluid.CustomPlace( +# 'intel_gpu', +# 0) if fluid.is_compiled_with_cuda() else fluid.CustomPlace( +# 'intel_gpu', 0) +# self.exe = fluid.Executor(self.place) + +# def set_program_and_run(self, main_program, case_num): +# with fluid.program_guard(main_program): +# x = [ +# fluid.data( +# name='x0', shape=self.shape, dtype="float32"), fluid.data( +# name='x1', shape=self.shape, dtype="float32"), +# fluid.data( +# name='x2', shape=self.shape, dtype="float32") +# ] + +# for each_x in x: +# each_x.stop_gradient = False + +# arr = layers.create_array(dtype="float32") +# for i in range(3): +# idx = layers.array_length(arr) +# arr = layers.array_write(x=x[i], i=idx, array=arr) + +# if case_num == 1: +# self.sliced_arr = output = arr[0] + +# elif case_num == 2: +# end = fluid.layers.array_length( +# arr) - 1 # dtype of end is int64 +# self.sliced_arr = slice_arr = arr[self.start:end] +# output, _ = fluid.layers.tensor_array_to_tensor( +# slice_arr, axis=self.axis, use_stack=True) +# elif case_num == 3: +# value_int64 = fluid.layers.fill_constant([1], "int64", +# 2147483648) +# self.sliced_arr = slice_arr = arr[self.start:value_int64] +# output, _ = fluid.layers.tensor_array_to_tensor( +# slice_arr, axis=self.axis, use_stack=True) + +# loss = fluid.layers.reduce_sum(output) +# fluid.backward.append_backward(loss) +# g_vars = list( +# map(main_program.global_block().var, +# [each_x.name + "@GRAD" for each_x in x])) +# self.out, self.g_x0, self.g_x1, self.g_x2 = \ +# self.exe.run(main_program, +# feed = {'x0': self.data, +# 'x1': self.data, +# 'x2': self.data}, +# fetch_list=[output] + g_vars) + +# def test_case_1(self): +# main_program = fluid.Program() +# self.set_program_and_run(main_program, 1) + +# self.assertTrue(self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR) +# self.assertEqual(self.sliced_arr.shape, self.shape) +# self.assertTrue(np.array_equal(self.out, self.data)) +# self.assertTrue(np.array_equal(self.g_x0, np.ones_like(self.data))) +# self.assertTrue(np.array_equal(self.g_x1, np.zeros_like(self.data))) +# self.assertTrue(np.array_equal(self.g_x2, np.zeros_like(self.data))) + +# def test_case_2(self): +# main_program = fluid.Program() +# self.set_program_and_run(main_program, 2) + +# self.assertTrue( +# self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY) +# self.assertEqual(self.sliced_arr.shape, self.shape) +# self.assertTrue( +# np.array_equal( +# self.out, np.stack( +# [self.data, self.data], axis=self.axis))) +# self.assertTrue(np.array_equal(self.g_x0, np.ones_like(self.data))) +# self.assertTrue(np.array_equal(self.g_x1, 
np.ones_like(self.data))) +# self.assertTrue(np.array_equal(self.g_x2, np.zeros_like(self.data))) + +# def test_case_3(self): +# main_program = fluid.Program() +# self.set_program_and_run(main_program, 3) + +# self.assertTrue( +# self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY) +# self.assertEqual(self.sliced_arr.shape, self.shape) +# self.assertTrue( +# np.array_equal( +# self.out, +# np.stack( +# [self.data, self.data, self.data], axis=self.axis))) +# self.assertTrue(np.array_equal(self.g_x0, np.ones_like(self.data))) +# self.assertTrue(np.array_equal(self.g_x1, np.ones_like(self.data))) +# self.assertTrue(np.array_equal(self.g_x2, np.ones_like(self.data))) + + +class TestImperativeVarBaseGetItem(unittest.TestCase): + def test_getitem_with_long(self): + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + data = np.random.random((2, 80, 16128)).astype("float32") + var = fluid.dygraph.to_variable(data) + sliced = var[:, 10:, : var.shape[1]] # var.shape[1] is 80L here + self.assertEqual(sliced.shape, [2, 70, 80]) + + sliced = var[:, var.shape[0] :, var.shape[0] : var.shape[1]] + self.assertEqual(sliced.shape, [2, 78, 78]) + + def test_getitem_with_float(self): + def test_float_in_slice_item(): + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + data = np.random.random((2, 80, 16128)).astype("float32") + var = fluid.dygraph.to_variable(data) + sliced = var[:, 1.1:, : var.shape[1]] + + self.assertRaises(Exception, test_float_in_slice_item) + + def test_float_in_index(): + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + data = np.random.random((2, 80, 16128)).astype("float32") + var = fluid.dygraph.to_variable(data) + sliced = var[1.1] + + self.assertRaises(Exception, test_float_in_index) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_softmax_op.py b/backends/intel_gpu/tests/unittests/test_softmax_op.py new file mode 100644 index 000000000..e89f6da0b --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_softmax_op.py @@ -0,0 +1,167 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
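+
+# Softmax kernel tests for the intel_gpu plugin: outputs of the OpTest cases
+# below are checked against a numpy reference (stable_softmax / ref_softmax).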
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn.functional as F + +np.random.seed(10) + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +def stable_softmax(x): + """Compute the softmax of vector x in a numerically stable way.""" + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.0) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + +def ref_softmax(x, axis=None, dtype=None): + x_t = x.copy() + if dtype is not None: + x_t = x_t.astype(dtype) + if axis is None: + axis = -1 + return np.apply_along_axis(stable_softmax, axis, x_t) + + +class TestSoftmaxOp(OpTest): + def get_x_shape(self): + return [10, 10] + + def get_axis(self): + return -1 + + def setUp(self): + self.op_type = "softmax" + self.use_cudnn = False + self.use_mkldnn = False + # TODO(Zhiwei35): currently only support fp32 softmax + self.dtype = np.float32 + self.init_kernel_type() + self.shape = self.get_x_shape() + self.axis = self.get_axis() + + np.random.seed(0) + x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + out = np.apply_along_axis(stable_softmax, self.axis, x) + + self.inputs = {"X": OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {"Out": out} + self.attrs = { + "axis": self.axis, + "use_cudnn": self.use_cudnn, + "use_mkldnn": self.use_mkldnn, + } + + def init_kernel_type(self): + pass + + def test_check_output(self): + self.check_output(check_dygraph=(self.use_mkldnn is False)) + + def test_check_grad(self): + self.check_grad( + ["X"], + "Out", + max_relative_error=0.01, + check_dygraph=(self.use_mkldnn is False), + ) + + +class TestSoftmaxOp2(TestSoftmaxOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + +class TestSoftmaxOp3(TestSoftmaxOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 0 + + +class TestSoftmaxOp4(TestSoftmaxOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 1 + + +class TestSoftmaxOp5(TestSoftmaxOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 2 + + +class TestSoftmaxOp6(TestSoftmaxOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 3 + + +class TestSoftmaxAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CustomPlace("intel_gpu", 0) + self.x_np = np.random.uniform(-1.0, 1.0, [2, 3, 4, 5]).astype("float32") + self.out_ref = np.apply_along_axis(stable_softmax, -1, self.x_np) + self.executed_api() + + def executed_api(self): + self.softmax = F.softmax + + def test_static_check(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data("X", self.x_np.shape, "float32") + out1 = self.softmax(x) + m = paddle.nn.Softmax() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={"X": self.x_np}, fetch_list=[out1, out2]) + out_ref = ref_softmax(self.x_np, axis=-1, dtype=None) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_check(self): + x = paddle.to_tensor(self.x_np) + out1 = self.softmax(x) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.Softmax() + out2 = m(x) + out_ref = ref_softmax(self.x_np, axis=-1, dtype=None) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/backends/intel_gpu/tests/unittests/test_transpose_op.py b/backends/intel_gpu/tests/unittests/test_transpose_op.py new file mode 100644 index 000000000..a9b18c206 --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_transpose_op.py @@ -0,0 +1,406 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +class TestTransposeOp(OpTest): + def setUp(self): + self.init_op_type() + self.initTestCase() + self.python_api = paddle.transpose + self.inputs = {"X": np.random.random(self.shape).astype("float32")} + self.attrs = { + "axis": list(self.axis), + "use_mkldnn": self.use_mkldnn, + } + self.outputs = { + "XShape": np.random.random(self.shape).astype("float32"), + "Out": self.inputs["X"].transpose(self.axis), + } + + def init_op_type(self): + self.op_type = "transpose2" + self.use_mkldnn = False + + def test_check_output(self): + self.check_output(no_check_set=["XShape"]) + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def initTestCase(self): + self.shape = (3, 40) + self.axis = (1, 0) + + +class TestCase0(TestTransposeOp): + def initTestCase(self): + self.shape = (100,) + self.axis = (0,) + + +class TestCase1(TestTransposeOp): + def initTestCase(self): + self.shape = (3, 4, 10) + self.axis = (0, 2, 1) + + +class TestCase2(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5) + self.axis = (0, 2, 3, 1) + + +class TestCase3(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.axis = (4, 2, 3, 1, 0) + + +class TestCase4(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6, 1) + self.axis = (4, 2, 3, 1, 0, 5) + + +class TestCase5(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 16, 96) + self.axis = (0, 2, 1) + + +class TestCase6(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 10, 12, 16) + self.axis = (3, 1, 2, 0) + + +class TestCase7(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 10, 2, 16) + self.axis = (0, 1, 3, 2) + + +class TestTransposeOpBool(TestTransposeOp): + def test_check_grad(self): + pass + + +class TestTransposeOpBool1D(TestTransposeOpBool): + def initTestCase(self): + self.shape = (100,) + self.axis = (0,) + self.inputs = {"X": np.random.random(self.shape).astype("bool")} + self.outputs = { + "XShape": np.random.random(self.shape).astype("bool"), + "Out": self.inputs["X"].transpose(self.axis), + } + + +class TestTransposeOpBool2D(TestTransposeOpBool): + def initTestCase(self): + self.shape = (3, 40) + self.axis = (1, 0) + self.inputs = {"X": np.random.random(self.shape).astype("bool")} + self.outputs = { + "XShape": 
np.random.random(self.shape).astype("bool"), + "Out": self.inputs["X"].transpose(self.axis), + } + + +class TestTransposeOpBool3D(TestTransposeOpBool): + def initTestCase(self): + self.shape = (3, 4, 10) + self.axis = (0, 2, 1) + self.inputs = {"X": np.random.random(self.shape).astype("bool")} + self.outputs = { + "XShape": np.random.random(self.shape).astype("bool"), + "Out": self.inputs["X"].transpose(self.axis), + } + + +class TestTransposeOpBool4D(TestTransposeOpBool): + def initTestCase(self): + self.shape = (2, 3, 4, 5) + self.axis = (0, 2, 3, 1) + self.inputs = {"X": np.random.random(self.shape).astype("bool")} + self.outputs = { + "XShape": np.random.random(self.shape).astype("bool"), + "Out": self.inputs["X"].transpose(self.axis), + } + + +class TestTransposeOpBool5D(TestTransposeOpBool): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.axis = (4, 2, 3, 1, 0) + self.inputs = {"X": np.random.random(self.shape).astype("bool")} + self.outputs = { + "XShape": np.random.random(self.shape).astype("bool"), + "Out": self.inputs["X"].transpose(self.axis), + } + + +class TestTransposeOpBool6D(TestTransposeOpBool): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6, 1) + self.axis = (4, 2, 3, 1, 0, 5) + self.inputs = {"X": np.random.random(self.shape).astype("bool")} + self.outputs = { + "XShape": np.random.random(self.shape).astype("bool"), + "Out": self.inputs["X"].transpose(self.axis), + } + + +class TestTransposeOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + x = fluid.layers.data(name="x", shape=[10, 5, 3], dtype="float32") + + def test_x_Variable_check(): + # the Input(x)'s type must be Variable + fluid.layers.transpose("not_variable", perm=[1, 0, 2]) + + self.assertRaises(TypeError, test_x_Variable_check) + + def test_x_dtype_check(): + # the Input(x)'s dtype must be one of [bool, float16, float32, float32, int32, int64] + x1 = fluid.layers.data(name="x1", shape=[10, 5, 3], dtype="int8") + fluid.layers.transpose(x1, perm=[1, 0, 2]) + + self.assertRaises(TypeError, test_x_dtype_check) + + def test_perm_list_check(): + # Input(perm)'s type must be list + fluid.layers.transpose(x, perm="[1, 0, 2]") + + self.assertRaises(TypeError, test_perm_list_check) + + def test_perm_length_and_x_dim_check(): + # Input(perm) is the permutation of dimensions of Input(input) + # its length should be equal to dimensions of Input(input) + fluid.layers.transpose(x, perm=[1, 0, 2, 3, 4]) + + self.assertRaises(ValueError, test_perm_length_and_x_dim_check) + + def test_each_elem_value_check(): + # Each element in Input(perm) should be less than Input(x)'s dimension + fluid.layers.transpose(x, perm=[3, 5, 7]) + + self.assertRaises(ValueError, test_each_elem_value_check) + + +class TestTransposeApi(unittest.TestCase): + def test_static_out(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[2, 3, 4], dtype="float32") + x_trans1 = paddle.transpose(x, perm=[1, 0, 2]) + x_trans2 = paddle.transpose(x, perm=(2, 1, 0)) + place = paddle.CustomPlace("intel_gpu", 0) + exe = paddle.static.Executor(place) + x_np = np.random.random([2, 3, 4]).astype("float32") + result1, result2 = exe.run( + feed={"x": x_np}, fetch_list=[x_trans1, x_trans2] + ) + expected_result1 = np.transpose(x_np, [1, 0, 2]) + expected_result2 = np.transpose(x_np, (2, 1, 0)) + + np.testing.assert_array_equal(result1, expected_result1) + np.testing.assert_array_equal(result2, expected_result2) + + def 
test_dygraph_out(self): + # This is an old test before 2.0 API so we need to disable static + # to trigger dygraph + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.randn([2, 3, 4]) + x_trans1 = paddle.transpose(x, perm=[1, 0, 2]) + x_trans2 = paddle.transpose(x, perm=(2, 1, 0)) + x_np = x.numpy() + expected_result1 = np.transpose(x_np, [1, 0, 2]) + expected_result2 = np.transpose(x_np, (2, 1, 0)) + + np.testing.assert_array_equal(x_trans1.numpy(), expected_result1) + np.testing.assert_array_equal(x_trans2.numpy(), expected_result2) + # This is an old test before 2.0 API so we enable static again after + # dygraph test + paddle.enable_static() + + +class TestTAPI(unittest.TestCase): + def test_out(self): + paddle.enable_static() + with fluid.program_guard(fluid.Program()): + data = fluid.data(shape=[10], dtype="float32", name="data") + data_t = paddle.t(data) + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + data_np = np.random.random([10]).astype("float32") + (result,) = exe.run(feed={"data": data_np}, fetch_list=[data_t]) + expected_result = np.transpose(data_np) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = fluid.data(shape=[10, 5], dtype="float32", name="data") + data_t = paddle.t(data) + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + data_np = np.random.random([10, 5]).astype("float32") + (result,) = exe.run(feed={"data": data_np}, fetch_list=[data_t]) + expected_result = np.transpose(data_np) + self.assertEqual((result == expected_result).all(), True) + + with fluid.program_guard(fluid.Program()): + data = fluid.data(shape=[1, 5], dtype="float32", name="data") + data_t = paddle.t(data) + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + data_np = np.random.random([1, 5]).astype("float32") + (result,) = exe.run(feed={"data": data_np}, fetch_list=[data_t]) + expected_result = np.transpose(data_np) + self.assertEqual((result == expected_result).all(), True) + + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + np_x = np.random.random([10]).astype("float32") + data = fluid.dygraph.to_variable(np_x) + z = paddle.t(data) + np_z = z.numpy() + z_expected = np.array(np.transpose(np_x)) + self.assertEqual((np_z == z_expected).all(), True) + + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + np_x = np.random.random([10, 5]).astype("float32") + data = fluid.dygraph.to_variable(np_x) + z = paddle.t(data) + np_z = z.numpy() + z_expected = np.array(np.transpose(np_x)) + self.assertEqual((np_z == z_expected).all(), True) + + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + np_x = np.random.random([1, 5]).astype("float32") + data = fluid.dygraph.to_variable(np_x) + z = paddle.t(data) + np_z = z.numpy() + z_expected = np.array(np.transpose(np_x)) + self.assertEqual((np_z == z_expected).all(), True) + + def test_errors(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data(name="x", shape=[10, 5, 3], dtype="float32") + + def test_x_dimension_check(): + paddle.t(x) + + self.assertRaises(ValueError, test_x_dimension_check) + + +class TestMoveAxis(unittest.TestCase): + def test_moveaxis1(self): + x_np = np.random.randn(2, 3, 4, 5, 7).astype("float32") + expected = np.moveaxis(x_np, [0, 4, 3, 2], [1, 3, 2, 0]) + paddle.enable_static() + with paddle.static.program_guard(fluid.Program()): + x = paddle.static.data("x", shape=[2, 3, 4, 5, 7], dtype="float32") + out = paddle.moveaxis(x, 
[0, 4, 3, 2], [1, 3, 2, 0]) + + exe = paddle.static.Executor(paddle.CustomPlace("intel_gpu", 0)) + out_np = exe.run(feed={"x": x_np}, fetch_list=[out])[0] + + self.assertEqual(np.array_equal(out_np, expected), True) + + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(x_np) + out = paddle.moveaxis(x, [0, 4, 3, 2], [1, 3, 2, 0]) + self.assertEqual(out.shape, [4, 2, 5, 7, 3]) + self.assertEqual(np.array_equal(out.numpy(), expected), True) + paddle.enable_static() + + def test_moveaxis2(self): + x_np = np.random.randn(2, 3, 5).astype("float32") + expected = np.moveaxis(x_np, -2, -1) + paddle.enable_static() + with paddle.static.program_guard(fluid.Program()): + x = paddle.static.data("x", shape=[2, 3, 5], dtype="float32") + out = x.moveaxis(-2, -1) + + exe = paddle.static.Executor(paddle.CustomPlace("intel_gpu", 0)) + out_np = exe.run(feed={"x": x_np}, fetch_list=[out])[0] + + self.assertEqual(np.array_equal(out_np, expected), True) + + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + x = paddle.to_tensor(x_np) + out = x.moveaxis(-2, -1) + self.assertEqual(out.shape, [2, 5, 3]) + self.assertEqual(np.array_equal(out.numpy(), expected), True) + paddle.enable_static() + + # def test_moveaxis3(self): + # paddle.disable_static(paddle.CustomPlace('intel_gpu', 0)) + # x = paddle.to_tensor( + # [[1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j], [1 + 1j, -1 - 1j]]) + # out = x.moveaxis(0, 1) + # self.assertEqual(out.shape, [2, 3]) + # paddle.enable_static() + + def test_error(self): + x = paddle.randn([2, 3, 4, 5]) + # src must have the same number with dst + with self.assertRaises(AssertionError): + paddle.moveaxis(x, [1, 0], [2]) + + # each element of src must be unique + with self.assertRaises(ValueError): + paddle.moveaxis(x, [1, 1], [0, 2]) + + # each element of dst must be unique + with self.assertRaises(ValueError): + paddle.moveaxis(x, [0, 1], [2, 2]) + + # each element of src must be integer + with self.assertRaises(AssertionError): + paddle.moveaxis(x, [0.5], [1]) + + # each element of dst must be integer + with self.assertRaises(AssertionError): + paddle.moveaxis(x, [0], [1.5]) + + # each element of src must be in the range of [-4, 3) + with self.assertRaises(AssertionError): + paddle.moveaxis(x, [-10, 1], [2, 3]) + + # each element of dst must be in the range of [-4, 3) + with self.assertRaises(AssertionError): + paddle.moveaxis(x, [2, 1], [10, 3]) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/intel_gpu/tests/unittests/test_uniform_random_op.py b/backends/intel_gpu/tests/unittests/test_uniform_random_op.py new file mode 100644 index 000000000..4ae7f0ece --- /dev/null +++ b/backends/intel_gpu/tests/unittests/test_uniform_random_op.py @@ -0,0 +1,571 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
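+
+# Uniform random kernel tests for the intel_gpu plugin. Because the outputs are
+# random, most checks below compare a histogram of the samples (output_hist)
+# against the expected uniform density rather than exact values.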
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid.core as core +import paddle +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +def get_places(self): + return [paddle.CustomPlace("intel_gpu", 0)] + + +OpTest._get_places = get_places + + +def output_hist(out): + hist, _ = np.histogram(out, range=(-5, 10)) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.1 * np.ones((10)) + return hist, prob + + +def output_hist_diag(out): + diag_num = min(out.shape) + for i in range(diag_num): + assert abs(out[i][i] - 1.0) < 1e-9 + # ignore diagonal elements + out[i][i] = 100 + hist, _ = np.histogram(out, range=(-5, 10)) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.1 * np.ones((10)) + return hist, prob + + +class TestUniformRandomOp_attr_tensorlist(OpTest): + def setUp(self): + self.op_type = "uniform_random" + self.python_api = paddle.uniform + self.new_shape = (1000, 784) + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones((1)).astype("int64") * ele)) + self.inputs = {"ShapeTensorList": shape_tensor} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestMaxMinAreInt(TestUniformRandomOp_attr_tensorlist): + def init_attrs(self): + self.attrs = {"min": -5, "max": 10, "seed": 10} + self.output_hist = output_hist + + +class TestUniformRandomOp_attr_tensorlist_int32(OpTest): + def setUp(self): + self.op_type = "uniform_random" + self.python_api = paddle.uniform + self.new_shape = (1000, 784) + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones((1)).astype("int32") * ele)) + self.inputs = {"ShapeTensorList": shape_tensor} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestUniformRandomOp_attr_tensor(OpTest): + def setUp(self): + self.op_type = "uniform_random" + self.python_api = paddle.uniform + self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int64")} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestUniformRandomOp_attr_tensor_int32(OpTest): + def setUp(self): + self.op_type = "uniform_random" + 
self.python_api = paddle.uniform + self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestUniformRandomOp(OpTest): + def setUp(self): + self.op_type = "uniform_random" + self.python_api = paddle.uniform + self.inputs = {} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"shape": [1000, 784], "min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + def test_check_api(self): + places = self._get_places() + for place in places: + with fluid.dygraph.base.guard(place=place): + out = self.python_api( + self.attrs["shape"], + "float32", + self.attrs["min"], + self.attrs["max"], + self.attrs["seed"], + ) + + # def test_check_api_eager(self): + # with _test_eager_guard(): + # self.test_check_api() + + +class TestUniformRandomOpError(unittest.TestCase): + def test_errors(self): + main_prog = Program() + start_prog = Program() + with program_guard(main_prog, start_prog): + + def test_Variable(): + x1 = fluid.create_lod_tensor( + np.zeros((4, 784)), + [[1, 1, 1, 1]], + fluid.CustomPlace("intel_gpu", 0), + ) + fluid.layers.uniform_random(x1) + + self.assertRaises(TypeError, test_Variable) + + def test_Variable2(): + x1 = np.zeros((4, 784)) + fluid.layers.uniform_random(x1) + + self.assertRaises(TypeError, test_Variable2) + + def test_dtype(): + x2 = fluid.layers.data(name="x2", shape=[4, 784], dtype="float32") + fluid.layers.uniform_random(x2, "int32") + + self.assertRaises(TypeError, test_dtype) + + def test_out_dtype(): + out = fluid.layers.uniform_random(shape=[3, 4], dtype="float64") + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) + + test_out_dtype() + + +class TestUniformRandomOpWithDiagInit(TestUniformRandomOp): + def init_attrs(self): + self.attrs = { + "shape": [1000, 784], + "min": -5.0, + "max": 10.0, + "seed": 10, + "diag_num": 784, + "diag_step": 784, + "diag_val": 1.0, + } + self.output_hist = output_hist_diag + + +class TestUniformRandomOpSelectedRows(unittest.TestCase): + def get_places(self): + places = [core.CustomPlace("intel_gpu", 0)] + return places + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + def check_with_place(self, place): + scope = core.Scope() + out = scope.var("X").get_selected_rows() + paddle.seed(10) + op = Operator( + "uniform_random", Out="X", shape=[1000, 784], min=-5.0, max=10.0, seed=10 + ) + op.run(scope, place) + self.assertEqual(out.get_tensor().shape(), [1000, 784]) + hist, prob = output_hist(np.array(out.get_tensor())) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestUniformRandomOpSelectedRowsWithDiagInit(TestUniformRandomOpSelectedRows): + def check_with_place(self, place): + scope = core.Scope() + out = 
scope.var("X").get_selected_rows() + paddle.seed(10) + op = Operator( + "uniform_random", + Out="X", + shape=[500, 784], + min=-5.0, + max=10.0, + seed=10, + diag_num=500, + diag_step=784, + diag_val=1.0, + ) + op.run(scope, place) + self.assertEqual(out.get_tensor().shape(), [500, 784]) + hist, prob = output_hist_diag(np.array(out.get_tensor())) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +# TODO(Zhiwei35): when solved mkl shared lib issue, the case will be enabled +# class TestUniformRandomOpApi(unittest.TestCase): +# def test_api(self): +# paddle.seed(10) +# x = fluid.layers.data('x', shape=[16], dtype='float32', lod_level=1) +# y = paddle.mm(x, +# size=16, +# param_attr=fluid.initializer.Uniform( +# low=-0.5, +# high=0.5, +# seed=10, +# diag_num=16, +# diag_step=16, +# diag_val=1.0)) + +# place = fluid.CustomPlace('intel_gpu', 0) +# x_tensor = fluid.create_lod_tensor( +# np.random.rand(3, 16).astype("float32"), [[1, 2]], place) +# exe = fluid.Executor(place) +# exe.run(fluid.default_startup_program()) +# ret = exe.run(feed={'x': x_tensor}, fetch_list=[y], return_numpy=False) + + +class TestUniformRandomOp_attr_tensor_API(unittest.TestCase): + def test_attr_tensor_API(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + dim_tensor = fluid.layers.fill_constant([1], "int64", 3) + ret = fluid.layers.nn.uniform_random([1, dim_tensor, 2]) + + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + + exe.run(startup_program) + outs = exe.run(train_program, fetch_list=[ret]) + + def test_attr_tensorlist_int32_API(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + dim_1 = fluid.layers.fill_constant([1], "int64", 3) + dim_2 = fluid.layers.fill_constant([1], "int32", 2) + ret = fluid.layers.nn.uniform_random([1, dim_1, dim_2]) + + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + + exe.run(startup_program) + outs = exe.run(train_program, fetch_list=[ret]) + + def test_attr_tensor_int32_API(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + shape = fluid.data(name="shape_tensor", shape=[2], dtype="int32") + ret = fluid.layers.nn.uniform_random(shape) + + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + Shape = np.array([2, 3]).astype("int32") + exe.run(startup_program) + outs = exe.run( + train_program, feed={"shape_tensor": Shape}, fetch_list=[ret] + ) + + +class TestUniformRandomOp_API_seed(unittest.TestCase): + def test_attr_tensor_API(self): + _seed = 10 + gen = paddle.seed(_seed) + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + _min = 5 + _max = 10 + + ret = fluid.layers.nn.uniform_random( + [2, 3, 2], min=_min, max=_max, seed=_seed + ) + ret_2 = fluid.layers.nn.uniform_random( + [2, 3, 2], min=_min, max=_max, seed=_seed + ) + res = fluid.layers.equal(ret, ret_2) + place = fluid.CustomPlace("intel_gpu", 0) + exe = fluid.Executor(place) + + exe.run(startup_program) + ret_value, cmp_value = exe.run(train_program, fetch_list=[ret, res]) + self.assertTrue(np.array(cmp_value).all()) + for i in ret_value.flatten(): + self.assertGreaterEqual(i, _min) + self.assertLess(i, _max) + + +class 
TestUniformRandomOpSelectedRowsShapeTensor(unittest.TestCase): + def get_places(self): + places = [core.CustomPlace("intel_gpu", 0)] + return places + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + def check_with_place(self, place): + scope = core.Scope() + out = scope.var("X").get_selected_rows() + shape_tensor = scope.var("Shape").get_tensor() + shape_tensor.set(np.array([1000, 784]).astype("int64"), place) + paddle.seed(10) + op = Operator( + "uniform_random", ShapeTensor="Shape", Out="X", min=-5.0, max=10.0, seed=10 + ) + op.run(scope, place) + self.assertEqual(out.get_tensor().shape(), [1000, 784]) + hist, prob = output_hist(np.array(out.get_tensor())) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestUniformRandomOpSelectedRowsShapeTensorList(unittest.TestCase): + def get_places(self): + places = [core.CustomPlace("intel_gpu", 0)] + return places + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + def check_with_place(self, place): + scope = core.Scope() + out = scope.var("X").get_selected_rows() + shape_1 = scope.var("shape1").get_tensor() + shape_1.set(np.array([1000]).astype("int64"), place) + shape_2 = scope.var("shape2").get_tensor() + shape_2.set(np.array([784]).astype("int64"), place) + paddle.seed(10) + op = Operator( + "uniform_random", + ShapeTensorList=["shape1", "shape2"], + Out="X", + min=-5.0, + max=10.0, + seed=10, + ) + op.run(scope, place) + self.assertEqual(out.get_tensor().shape(), [1000, 784]) + hist, prob = output_hist(np.array(out.get_tensor())) + self.assertTrue( + np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist) + ) + + +class TestUniformRandomDygraphMode(unittest.TestCase): + def test_check_output(self): + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + x = fluid.layers.uniform_random([10], dtype="float32", min=0.0, max=1.0) + x_np = x.numpy() + for i in range(10): + self.assertTrue((x_np[i] > 0 and x_np[i] < 1.0)) + + +class TestUniformRandomBatchSizeLikeOpError(unittest.TestCase): + def test_errors(self): + main_prog = Program() + start_prog = Program() + with program_guard(main_prog, start_prog): + + def test_Variable(): + x1 = fluid.create_lod_tensor( + np.zeros((100, 784)), + [[10, 10, 10, 70]], + fluid.CustomPlace("intel_gpu", 0), + ) + fluid.layers.uniform_random_batch_size_like(x1) + + self.assertRaises(TypeError, test_Variable) + + def test_shape(): + x1 = fluid.layers.data(name="x2", shape=[100, 784], dtype="float32") + fluid.layers.uniform_random_batch_size_like(x1, shape="shape") + + self.assertRaises(TypeError, test_shape) + + def test_dtype(): + x2 = fluid.layers.data(name="x2", shape=[100, 784], dtype="float32") + fluid.layers.uniform_random_batch_size_like(x2, "int32") + + self.assertRaises(TypeError, test_dtype) + + +class TestUniformAlias(unittest.TestCase): + def test_alias(self): + paddle.uniform([2, 3], min=-5.0, max=5.0) + paddle.tensor.uniform([2, 3], min=-5.0, max=5.0) + paddle.tensor.random.uniform([2, 3], min=-5.0, max=5.0) + + def test_uniform_random(): + paddle.tensor.random.uniform_random([2, 3], min=-5.0, max=5.0) + + self.assertRaises(AttributeError, test_uniform_random) + + +class TestUniformOpError(unittest.TestCase): + def test_errors(self): + main_prog = Program() + start_prog = Program() + with program_guard(main_prog, start_prog): + + def test_Variable(): + x1 = fluid.create_lod_tensor( + np.zeros((100, 784)), + [[10, 10, 10, 70]], + 
fluid.CustomPlace("intel_gpu", 0), + ) + paddle.tensor.random.uniform(x1) + + self.assertRaises(TypeError, test_Variable) + + def test_Variable2(): + x1 = np.zeros((100, 784)) + paddle.tensor.random.uniform(x1) + + self.assertRaises(TypeError, test_Variable2) + + def test_dtype(): + x2 = fluid.layers.data(name="x2", shape=[100, 784], dtype="float32") + paddle.tensor.random.uniform(x2, "int32") + + self.assertRaises(TypeError, test_dtype) + + def test_out_dtype(): + out = paddle.tensor.random.uniform(shape=[3, 4], dtype="float64") + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) + + test_out_dtype() + + +class TestUniformDygraphMode(unittest.TestCase): + def test_check_output(self): + with fluid.dygraph.guard(paddle.CustomPlace("intel_gpu", 0)): + x = paddle.tensor.random.uniform([10], dtype="float32", min=0.0, max=1.0) + x_np = x.numpy() + for i in range(10): + self.assertTrue((x_np[i] > 0 and x_np[i] < 1.0)) + + +class TestUniformDtype(unittest.TestCase): + def test_default_dtype(self): + paddle.disable_static(paddle.CustomPlace("intel_gpu", 0)) + + def test_default_fp16(): + paddle.framework.set_default_dtype("float16") + paddle.tensor.random.uniform([2, 3]) + + self.assertRaises(TypeError, test_default_fp16) + + def test_default_fp32(): + paddle.framework.set_default_dtype("float32") + out = paddle.tensor.random.uniform([2, 3]) + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP32) + + def test_default_fp64(): + paddle.framework.set_default_dtype("float64") + out = paddle.tensor.random.uniform([2, 3]) + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) + + test_default_fp64() + test_default_fp32() + + paddle.enable_static() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()