diff --git a/.github/workflows/dockerimage.yml b/.github/workflows/dockerimage.yml index 0ba5adb3..be1ce169 100644 --- a/.github/workflows/dockerimage.yml +++ b/.github/workflows/dockerimage.yml @@ -5,10 +5,13 @@ on: branches: - master - develop + - feature/decoder pull_request: branches: - master - develop + - feature/decoder + jobs: build: diff --git a/CMakeLists.txt b/CMakeLists.txt index e5f7877b..4b17e50e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,9 +21,9 @@ set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_FLAGS "-Wall") set(CMAKE_C_FLAGS "-Wall") -set(TURBO_TRANSFORMERS_VERSION 0.2.1) +set(TURBO_TRANSFORMERS_VERSION 0.3.0) -option(WITH_PROFILER "Compile with gperftools" OFF) +option(WITH_PROFILER "Compile with profiler" OFF) option(WITH_GPU "Build with GPU" OFF) option(WITH_MODULE_BENCHMAKR "Catch2 unitest with benchmarking" ON) @@ -65,9 +65,7 @@ endif () if (WITH_PROFILER) - find_package(Gperftools REQUIRED) - include_directories(${GPERFTOOLS_INCLUDE_DIR}) - add_definitions(-DWITH_GPERFTOOLS) + add_definitions(-DWITH_PERFTOOLS) endif () IF (UNIX AND NOT APPLE) diff --git a/Dockerfile_ci b/Dockerfile_ci index 31b08585..b3304a32 100644 --- a/Dockerfile_ci +++ b/Dockerfile_ci @@ -1,14 +1,15 @@ -FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 +FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 RUN apt-get update && \ - apt-get install -y curl git wget bzip2 build-essential ninja-build g++ && rm -rf /var/lib/apt/lists/* + apt-get install -y curl git wget bzip2 build-essential ninja-build g++ gfortran && rm -rf /var/lib/apt/lists/* ENV PATH=/opt/miniconda3/bin:${PATH} CONDA_PREFIX=/opt/miniconda3 RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ bash Miniconda3-latest-Linux-x86_64.sh -p /opt/miniconda3 -b && \ rm Miniconda3-latest-Linux-x86_64.sh && \ conda update -y conda && \ - conda install pytorch==1.4.0 cudatoolkit=10.0 && \ + conda install pytorch==1.5.0 cudatoolkit=10.0 && \ + pip install OpenNMT-py && \ conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \ conda install git git-lfs docopt -c conda-forge && \ conda clean -afy diff --git a/README.md b/README.md index cbad1020..f542566d 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ ## turbo_transformers: a fast and user-friendly tool for transformer inference on CPU and GPU ![logo](./images/logo.jpeg) -### **make transformers serving fast by adding a turbo to your inference engine!** +
**make transformers serving fast by adding a turbo to your inference engine!**
+### Background
Transformer is the most critical algorithm innovation in the NLP field in recent years. It brings higher model accuracy while introducing more computation. The efficient deployment of online Transformer-based services faces enormous challenges. In order to make the costly Transformer online service more efficient, WeChat AI open-sourced a Transformer inference acceleration tool called TurboTransformers, which has the following characteristics.
-1. Excellent CPU / GPU performance. For Intel multi-core CPU and NVIDIA GPU hardware platforms, TurboTransformers can fully utilize all levels of computing power of the hardware. It has achieved better performance over pytorch / tensorflow and current mainstream optimization engines (such as onnxruntime-mkldnn / onnxruntime-gpu, torch JIT, NVIDIA faster transformers) on a variety of CPU and GPU hardware. See the detailed benchmark results below.
-2. Tailored to the characteristics of NLP inference tasks. Unlike the CV task, the input dimensions of the NLP inference task always change. The traditional approach is zero padding or truncation to a fixed length, which introduces additional zero padding computational overhead. Besides, some frameworks such as onnxruntime, tensorRT, and torchlib need to preprocess the compuatation-graph according to the input size in advance for the best performance, which is not suitable for NLP tasks with varying sizes. TurboTransformers can support variable-length input sequence processing without preprocessing.
-3. A simpler method of use. TurboTransformers supports python and C ++ interface for calling. It can be used as an acceleration plug-in for pytorch. In the Transformer task, the end-to-end acceleration effect obtained by adding a few lines of python code.
+1. Supports both the Transformer Encoder and Decoder.
+2. Excellent CPU / GPU performance. For Intel multi-core CPU and NVIDIA GPU hardware platforms, TurboTransformers can fully utilize all levels of computing power of the hardware. It has achieved better performance than pytorch / tensorflow and current mainstream optimization engines (such as onnxruntime-mkldnn / onnxruntime-gpu, torch JIT, NVIDIA faster transformers) on a variety of CPU and GPU hardware. See the detailed benchmark results below.
+3. Tailored to the characteristics of NLP inference tasks. Unlike CV tasks, the input dimensions of NLP inference tasks always change. The traditional approach is zero padding or truncation to a fixed length, which introduces additional zero-padding computational overhead. Besides, some frameworks such as onnxruntime, tensorRT, and torchlib need to preprocess the computation graph according to the input size in advance for the best performance, which is not suitable for NLP tasks with varying input sizes. TurboTransformers supports variable-length input sequences without preprocessing.
+4. A simpler method of use. TurboTransformers provides python and C++ interfaces. It can be used as an acceleration plug-in for pytorch: in a Transformer task, end-to-end acceleration is obtained by adding a few lines of python code (a minimal sketch follows at the end of this section).
TurboTransformers has been applied to multiple online BERT service scenarios in Tencent. For example, it brings 1.88x acceleration to the WeChat FAQ service, 2.11x acceleration to the public cloud sentiment analysis service, and 13.6x acceleration to the QQ recommendation system.
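To make point 4 concrete, here is a minimal sketch of the plug-in style usage. The `from_torch` entry point and the returned outputs follow the scripts in ./example/python; treat the exact call signature and outputs as illustrative rather than a frozen API, and the model name is only an example.
```python
# Minimal sketch of the "few lines of Python" usage described in point 4.
# Method names follow ./example/python; exact signatures are assumptions.
import torch
import transformers
import turbo_transformers

torch_model = transformers.BertModel.from_pretrained("bert-base-chinese")
torch_model.eval()

# Wrap the PyTorch weights with the Turbo runtime.
turbo_model = turbo_transformers.BertModel.from_torch(torch_model)

input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long)
with torch.no_grad():
    sequence_output, pooled_output = turbo_model(input_ids)
```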
@@ -45,7 +47,7 @@ sh tools/build_and_run_unittests.sh $PWD -DWITH_GPU=OFF
# set(BLAS_PROVIDER "mkl" CACHE STRING "Set the blas provider library, in [openblas, mkl, blis]")
```
-Method 2:I do not want to unitest
+Method 2: I do not want to run unit tests
```
cd /workspace
mkdir -p build && cd build
@@ -67,7 +69,7 @@ sh tool/build_conda_package.sh
*We also prepared a docker image containing the CPU version of TurboTransformers, as well as other related works, i.e. onnxrt v1.2.0 and pytorch-jit, on dockerhub.*
```
-docker pull thufeifeibear/turbo_transformers:0.2.0-release-cpu-dev
+docker pull thufeifeibear/turbo_transformers:0.3.0-cpu-dev
```
### Installation on GPU
```
@@ -77,8 +79,8 @@ git clone https://github.com/Tencent/TurboTransformers --recursive
```
# You can modify the environment variables in the script to specify the cuda version and operating system version
sh tools/build_docker_gpu.sh $PWD
-docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=your_container_name REPOSITORY:TAG
-# for example: docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=jiarui_gpu_env ccr.ccs.tencentyun.com/mmspr/turbo_transformers:0.1.1-cuda9.0-ubuntu16.04-gpu-dev
+nvidia-docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=your_container_name REPOSITORY:TAG
+# for example: nvidia-docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=jiarui_gpu_env ccr.ccs.tencentyun.com/mmspr/turbo_transformers:0.1.1-cuda9.0-ubuntu16.04-gpu-dev
```
2. Install the pip package in docker and run the unit tests
@@ -92,68 +94,42 @@ sh tools/build_and_run_unittests.sh $PWD -DWITH_GPU=ON
cd benchmark
bash gpu_run_benchmark.sh
```
-*We also prepared a docker image containing GPU version of TurboTransformers.
+We also prepared a docker image containing the GPU version of TurboTransformers.
```
-docker pull thufeifeibear/turbo_transformers:0.2.0-cuda10.0-cudnn7-devel-ubuntu18.04-gpu-release
+docker pull thufeifeibear/turbo_transformers:0.3.0-cuda10.0-cudnn7-devel-ubuntu18.04-gpu-dev
```
### Usage
-turbo_transformers provides C ++ / python API interfaces. we hope to do our best to adapt to a variety of online environments to reduce the difficulty of development for users.
+TurboTransformers provides C++ / python API interfaces. We do our best to adapt to a variety of online environments and reduce the development difficulty for users.
The first step in using turbo is to load a pre-trained model. We provide a way to load pytorch and tensorflow pre-trained models from [huggingface/transformers](https://github.com/huggingface): use the corresponding script in ./tools to convert the pre-trained model into an npz format file, which turbo then loads through its C++ or python interface. In particular, since most pre-trained models are in pytorch format and used from python, we provide a shortcut for calling a pytorch saved model directly in python.
-加载预训练模型
+pretrained
-#### python APIs
+#### Bert Examples
+##### python APIs
Refer to examples in [./example/python](./example/python "python"). Since users of BERT acceleration usually require customized post-processing for their task, we provide an example of how to write a sequence classification application.
-#### C++ APIs
+##### C++ APIs
Refer to [./example/cpp](./example/cpp "C++") for an example; a minimal construction sketch follows.
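The sketch below paraphrases the model construction in example/cpp/bert_model_example.cpp (shown later in this diff). The npz path is a placeholder, the two trailing constructor arguments are presumably the layer and head counts, and the inference call itself is deliberately omitted; see the example file for the full invocation.
```cpp
// Construction sketch based on example/cpp/bert_model_example.cpp; the model
// path is a placeholder and the inference call is omitted on purpose.
#include <memory>
#include <vector>

#include "bert_model.h"

int main() {
  // npz file produced by the conversion scripts under ./tools.
  auto model = std::make_shared<BertModel>("bert_base.npz",
                                           DLDeviceType::kDLCPU, 12, 12);
  // Rows of input_ids (and position/segment ids) may have different lengths.
  std::vector<std::vector<int64_t>> input_ids{{12166, 10699, 16752, 4454},
                                              {5342, 16471, 817, 16022}};
  // Build position_ids / segment_ids the same way, then invoke the model as
  // done in test() and test_multiple_threads() of the example.
  return 0;
}
```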
Our example provides the GPU and two CPU multi-thread calling methods. One is to do one BERT inference using multiple threads; the other is to do multiple BERT inferences, each of which uses one thread. Users can link turbo-transformers into their code through add_subdirectory.
-## Performance
-### CPU
-We tested the performance of TurboTransformers on three CPU hardware platforms.
-We choose [pytorch](https://github.com/huggingface "pytorch"), [pytorch-jit](https://pytorch.org/docs/stable/_modules/torch/jit.html "pytorch-jit" ) and [onnxruntime-mkldnn](https://github.com/microsoft/onnxruntime "onnxruntime-mkldnn") and TensorRT implementation as a comparison. The performance test result is the average of 150 iterations. In order to avoid the phenomenon that the data of the last iteration is cached in the cache during multiple tests, each test uses random data and refreshes the cache data after calculation.
-* Intel Xeon 61xx
-
-61xx性能
-61xx加速
-
-* Intel Xeon 6133
-Compared to the 61xx model, Intel Xeon 6133 has a longer vectorized length of 512 bits, and it has a 30 MB shared L3 cache between cores.
-
-6133性能
-6133加速
-
-### GPU
-We tested the performance of turbo_transformers on four GPU hardware platforms.
-We choose [pytorch](https://github.com/huggingface "pytorch"), [NVIDIA Faster Transformers](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer "FasterTransformer"), [onnxruntime-gpu](https://github.com/microsoft/onnxruntime "onnxrt-gpu") and [TensorRT](https://github.com/NVIDIA/TensorRT/tree/release/6.0/demo/BERT) implementation as a comparison. The performance test result is the average of 150 iterations.
-
-* RTX 2060
-2060性能
-2060加速
+#### Decoder Examples
+[TurboNLP/Translate-Demo](https://github.com/TurboNLP/Translate-Demo "translate") shows a demo of applying TurboTransformers to a translation task.
-* Tesla V100
-
-V100性能
-V100加速
-
-* Tesla P40
-
-P40性能
-P40加速
+## Performance
+[BERT Benchmark Results](./docs/bert.md)
-* Tesla M40
+[Transformer Decoder Results](./docs/decoder.md)
-M40性能
-M40加速
+[How to find the hotspots of your code](./docs/profiler.md)
## TODO
-Currently (April 2020), we only support a interface of the BERT encoder model using FP32. In the near futuer, we will add support for other models (GPT2, decoders, etc.) and low-precision floating point (CPU int8, GPU FP16).
+As of June 2020, we plan to add support for other models (Albert [Work In Progress], GPT2) and low-precision floating point (CPU int8, GPU FP16) in the near future.
+**Looking forward to your contributions!**
## License
BSD 3-Clause License
@@ -162,11 +138,10 @@ BSD 3-Clause License
1. The results of TurboTransformers may differ from those of PyTorch beyond the second decimal place. The difference mainly comes from the Bert Output Layer. We use an approximate GELU algorithm, which may differ from PyTorch's (the common tanh approximation is sketched after the History list below).
-2. On AuthenticAMD CPU, member function `from_torch` of class `BertModelWithPooler` and `BertModel` does not support PyTorch version as 1.5.0.
-In our opinion, the tensor transpose API of PyTorch is not stable. We use the following way to transpose weight matrices.
-```
-weight = torch.clone(torch.t(pooler_params['dense.weight']))
-```
+### History
+1. April 2020 v0.0.1, TurboTransformers was released, achieving state-of-the-art BERT inference speed on CPU/GPU.
+2. June 2020 v0.2.1, TurboTransformers added BLIS as a BLAS option, bringing better performance on AMD CPUs.
+3. June 2020 v0.3.0, TurboTransformers added support for the Transformer Decoder on CPU/GPU.
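For reference on the GELU point in the known issues above, the tanh-based approximation commonly used by fast inference kernels is sketched below. That Turbo uses exactly this variant is an assumption, but it illustrates why results can drift from PyTorch's exact erf-based GELU a couple of decimal places out.
```python
import math

def gelu_tanh_approx(x: float) -> float:
    # Tanh approximation of GELU; slightly different from the exact
    # erf-based definition used by PyTorch's default nn.GELU.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x ** 3)))
```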
## Contact us Although we recommand you post your problem with github issues, you can also join in our Turbo user group. diff --git a/benchmark/run_gpu_benchmark.sh b/benchmark/run_gpu_benchmark.sh index 0ec0a3ae..44c5ad81 100644 --- a/benchmark/run_gpu_benchmark.sh +++ b/benchmark/run_gpu_benchmark.sh @@ -18,6 +18,7 @@ FRAMEWORKS=("turbo-transformers" "torch") # FRAMEWORKS=("onnxruntime") SEQ_LEN=(10 20 40 60 80 100 200 300 400 500) BATCH_SIZE=(1 20) + N=150 MODEL="bert-base-chinese" for batch_size in ${BATCH_SIZE[*]} diff --git a/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp b/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp index 4292faf1..3794e2d6 100644 --- a/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp +++ b/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp @@ -27,7 +27,7 @@ using layers::kernels::common::FillRandom; static void MatmulBenchmarkHelper(DLDeviceType device_type, bool trans_weight, std::initializer_list weight_shape, std::vector m_list) { - constexpr int n_step = 1000; + constexpr int n_step = 100; const std::string device_name = device_type == kDLCPU ? "CPU" : "GPU"; const std::string trans_name = trans_weight ? "Tran" : "NoTrans"; @@ -61,27 +61,15 @@ static void MatmulBenchmarkHelper(DLDeviceType device_type, bool trans_weight, ss << device_name << " " << trans_name << " MatMul " << m << ", " << k << ", " << n << " "; auto g_flops = m * n * k * 2 / 1e9; - - if (device_type == kDLGPU) { -#ifdef TT_WITH_CUDA - auto flops = benchmark::TestFuncSpeed( - [&]() { - layers::kernels::MatMul(input_tensor, false, weight_tensor, - trans_weight, 1.0, &output_tensor, 0.0); - }, - n_step, ss.str(), g_flops, device_type); - - std::cout << ss.str() << " flops: " << flops << std::endl; -#endif - } else { - benchmark::TestFuncSpeed( - [&]() { - layers::kernels::MatMul(input_tensor, false, weight_tensor, - trans_weight, 1.0, &output_tensor, 0.0); - }, - n_step, ss.str(), g_flops, device_type); - } - } + auto flops = benchmark::TestFuncSpeed( + [&]() { + layers::kernels::MatMul(input_tensor, false, weight_tensor, + trans_weight, 1.0, &output_tensor, 0.0); + }, + n_step, ss.str(), g_flops, device_type); + + std::cout << ss.str() << " flops: " << flops << std::endl; + } // for } TEST_CASE("matmal-cpu-benchmark") { @@ -93,6 +81,62 @@ TEST_CASE("matmal-cpu-benchmark") { std::cout << std::endl; } +static void MatmulBenchmarkGeneralHelper(DLDeviceType device_type, + bool trans_weight, + std::vector dim_list) { + constexpr int n_step = 1000; + const std::string device_name = device_type == kDLCPU ? "CPU" : "GPU"; + const std::string trans_name = trans_weight ? 
"Trans" : "NoTrans"; + + for (auto m : dim_list) { + std::initializer_list input_shape{m, m}; + std::initializer_list weight_shape{m, m}; + std::initializer_list output_shape{m, m}; + + using turbo_transformers::core::NewDLPackTensorT; + + core::Tensor input_tensor( + NewDLPackTensorT(input_shape, device_type, 0)); + FillRandom(input_tensor); + + core::Tensor weight_tensor( + NewDLPackTensorT(weight_shape, device_type, 0)); + FillRandom(weight_tensor); + + core::Tensor output_tensor( + NewDLPackTensorT(output_shape, device_type, 0)); + FillRandom(output_tensor); + + std::stringstream ss; + ss << device_name << " " << trans_name << " MatMul " << m << ", " << m + << ", " << m << " "; + auto g_flops = m * m * m * 2 / 1e9; + auto flops = benchmark::TestFuncSpeed( + [&]() { + layers::kernels::MatMul(input_tensor, false, weight_tensor, + trans_weight, 1.0, &output_tensor, 0.0); + }, + n_step, ss.str(), g_flops, device_type); + + std::cout << ss.str() << " flops: " << flops << std::endl; + } // for +} + +TEST_CASE("matmal-cpu-benchmark-general") { +#if defined(TT_BLAS_USE_MKL) + std::cout << "blas uses MKL" << std::endl; +#elif defined(TT_BLAS_USE_OPENBLAS) + std::cout << "blas uses OpenBLAS" << std::endl; +#elif defined(TT_BLAS_USE_BLIS) + std::cout << "blas uses BLIS" << std::endl; +#endif + std::cout << "=================================" << std::endl; + std::cout << "CPU General MatMul Benchmark" << std::endl; + std::vector dim_list{10, 50, 100, 500, 1000, 1500, 2000}; + MatmulBenchmarkGeneralHelper(kDLCPU, false, dim_list); + std::cout << std::endl; +} + #ifdef TT_WITH_CUDA TEST_CASE("matmal-gpu-gemm7-benchmark") { diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index c21cb09b..8e74829a 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -25,4 +25,4 @@ foreach(X ${ARCH_FLAGS}) endforeach() message(STATUS "Generating CUDA code for ${CUDA_VERSION} SMs: ${CUDA_FLAGS}") -set(CMAKE_CUDA_FLAGS "${CUDA_FLAGS} -Xcompiler -Wall -std=c++11 --expt-relaxed-constexpr --use_fast_math --expt-extended-lambda") +set(CMAKE_CUDA_FLAGS "${CUDA_FLAGS} -Xcompiler -Wall --expt-relaxed-constexpr --use_fast_math --expt-extended-lambda") diff --git a/docs/bert.md b/docs/bert.md new file mode 100644 index 00000000..12c24c0b --- /dev/null +++ b/docs/bert.md @@ -0,0 +1,38 @@ +We show BERT inference performance here. + +### CPU +We tested the performance of TurboTransformers on three CPU hardware platforms. +We choose [pytorch](https://github.com/huggingface "pytorch"), [pytorch-jit](https://pytorch.org/docs/stable/_modules/torch/jit.html "pytorch-jit" ) and [onnxruntime-mkldnn](https://github.com/microsoft/onnxruntime "onnxruntime-mkldnn") and TensorRT implementation as a comparison. The performance test result is the average of 150 iterations. In order to avoid the phenomenon that the data of the last iteration is cached in the cache during multiple tests, each test uses random data and refreshes the cache data after calculation. +* Intel Xeon 61xx + +61xx性能 +61xx加速 + +* Intel Xeon 6133 +Compared to the 61xx model, Intel Xeon 6133 has a longer vectorized length of 512 bits, and it has a 30 MB shared L3 cache between cores. + +6133性能 +6133加速 + +### GPU +We tested the performance of turbo_transformers on four GPU hardware platforms. 
+We choose [pytorch](https://github.com/huggingface "pytorch"), [NVIDIA Faster Transformers](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer "FasterTransformer"), [onnxruntime-gpu](https://github.com/microsoft/onnxruntime "onnxrt-gpu") and [TensorRT](https://github.com/NVIDIA/TensorRT/tree/release/6.0/demo/BERT) implementation as a comparison. The performance test result is the average of 150 iterations. + +* RTX 2060 +2060性能 +2060加速 + +* Tesla V100 + +V100性能 +V100加速 + +* Tesla P40 + +P40性能 +P40加速 + +* Tesla M40 + +M40性能 +M40加速 diff --git a/docs/decoder.md b/docs/decoder.md new file mode 100644 index 00000000..a681dfa9 --- /dev/null +++ b/docs/decoder.md @@ -0,0 +1,6 @@ +We show Transformer Decoder inference performance here. + +For a translation demo [TurboNLP/Translate-Demo](https://github.com/TurboNLP/Translate-Demo "translate"), +Turbo will bring 15.9% performance improvements on RTX 2060 GPU. + +We are still working on decoder model optimization. diff --git a/docs/profiler.md b/docs/profiler.md new file mode 100644 index 00000000..a0ab39f8 --- /dev/null +++ b/docs/profiler.md @@ -0,0 +1,68 @@ +## How to profile you code +1. Compiling code before setting option WITH_PROFILER ON in CMakeList.txt + +``` +option(WITH_PROFILER "Compile with profiler" ON) +``` +2. Add profiling context in your code, for example + +``` +with turbo_transformers.pref_guard("info") as perf: + dec_out, dec_attn = self.turbo_decoder( + decoder_in, memory_bank, memory_lengths=memory_lengths, step=step + ) +``` + +3. The profiling results will be shown on your screen, like this +``` +info Time line: +context/values/AddBiasTransposeForScore/reshape , 0.023328, 0.00835687 % +context/keys/AddBiasTransposeForScore/Reshape , 0.030464, 0.0109132 % +context/gemm2/k_out1/Reshape , 0.040384, 0.0144669 % +context/keys/AddBiasTransposeForScore , 0.04688, 0.016794 % +context/gemm1/v_out1/Reshape , 0.049568, 0.0177569 % +context/values/AddBiasTransposeForScore , 0.050304, 0.0180206 % +context/gemm2 , 0.300832, 0.107768 % +context/gemm1 , 0.322112, 0.115391 % +context/AddBiasTransposeForScore/q_out2/Reshape , 0.515776, 0.184768 % +context/AddBiasTransposeForScore/q_out1/Reshape , 0.616032, 0.220683 % +context/gemm0/q_out2/Reshape , 0.773504, 0.277095 % +self/self_value/Reshape , 0.794784, 0.284718 % +self/layernorm/Reshape , 0.801984, 0.287297 % +FFN/Reshape , 0.851904, 0.30518 % +self/self_key/Reshape , 0.986688, 0.353464 % +FFN/AddBiasAct , 1.28634, 0.460808 % +context/gemm0/q_out1/Reshape , 1.35478, 0.485329 % +self/qkv_out1/Reshape , 1.5048, 0.539069 % +context/AddBiasTransposeForScore , 1.5057, 0.53939 % +self/SplitAddBiasTransposeForScore , 1.56646, 0.561159 % +FFN/LayerNorm , 1.64701, 0.590013 % +gemm5/Reshape , 1.65885, 0.594255 % +context/gemm0/prelayernorm , 1.66512, 0.596501 % +LayerNorm , 1.6856, 0.603838 % +self/self_value/Copy , 1.68688, 0.604297 % +batch_gemm4/Reshape , 1.69667, 0.607804 % +Concat/Reshape , 1.796, 0.643387 % +ApplyMaskAndSoftmax/Reshape , 1.80499, 0.646608 % +batch_gemm3/Reshape , 2.03645, 0.729523 % +Reshape , 2.1289, 0.762641 % +self/layernorm/Copy , 2.53923, 0.909637 % +self/qkv_out2/Reshape , 2.65715, 0.95188 % +Concat , 2.76022, 0.988804 % +context/gemm0/prelayernorm/Copy , 2.83021, 1.01387 % +batch_gemm4 , 3.00442, 1.07628 % +self/self_key/Copy , 3.07203, 1.1005 % +batch_gemm3 , 3.34592, 1.19862 % +ApplyMaskAndSoftmax , 3.67014, 1.31477 % +TransposeForScore , 3.76816, 1.34988 % +FFN/AddInputBias , 3.97325, 1.42335 % +self/keys/Concat , 4.66528, 1.67126 % 
+self/values/Concat , 5.08947, 1.82322 % +context/gemm0 , 5.76464, 2.06509 % +self/gemm012_fused , 7.82285, 2.8024 % +gemm5 , 11.295, 4.04626 % +FFN/gemm0 , 12.2295, 4.381 % +FFN/gemm1 , 17.4551, 6.25299 % +MultiHeadedAttention_context , 60.8736, 21.8069 % +MultiHeadedAttention_self , 91.1025, 32.6359 % +``` diff --git a/example/cpp/bert_model_example.cpp b/example/cpp/bert_model_example.cpp index 335094fd..f6d3010a 100644 --- a/example/cpp/bert_model_example.cpp +++ b/example/cpp/bert_model_example.cpp @@ -11,8 +11,6 @@ // permissions and limitations under the License. // See the AUTHORS file for names of contributors. -#include "bert_model.h" - #include #include #include @@ -20,6 +18,7 @@ #include #include +#include "bert_model.h" #include "turbo_transformers/core/config.h" static bool test(const std::string &model_path, bool use_cuda = false) { @@ -54,9 +53,13 @@ static std::vector CallBackFunction( } bool test_multiple_threads(const std::string &model_path, bool only_input, - int n_threads) { + bool use_cuda, int n_threads) { std::shared_ptr model_ptr = std::make_shared(model_path, DLDeviceType::kDLCPU, 12, 12); + // input_ids, position_ids, segment_ids lengths of each row may not be the + // same. For example. std::vector> input_ids{{1, 2, 3, 4, + // 5, 6, 7}, + // {1, 2}}; std::vector> input_ids{{12166, 10699, 16752, 4454}, {5342, 16471, 817, 16022}}; std::vector> position_ids{{1, 0, 0, 0}, {1, 1, 1, 0}}; @@ -94,6 +97,10 @@ bool test_multiple_threads(const std::string &model_path, bool only_input, // bert-base-uncased (2020.04.23 version), you may need to change it to // real-time values. if (only_input) { + std::cerr << vec.data()[0] << std::endl; + std::cerr << vec.data()[1] << std::endl; + std::cerr << vec.data()[768] << std::endl; + std::cerr << vec.data()[768 + 1] << std::endl; assert(fabs(vec.data()[0] - -0.1901) < 1e-3); assert(fabs(vec.data()[1] - 0.0193) < 1e-3); assert(fabs(vec.data()[768] - 0.3060) < 1e-3); @@ -132,8 +139,10 @@ int main(int argc, char *argv[]) { turbo_transformers::core::SetNumThreads(1); if (core::IsCompiledWithCUDA()) { std::cout << "10 threads do 10 independent bert inferences." 
<< std::endl; - test_multiple_threads(model_path, true /*use cuda*/, 10); + test_multiple_threads(model_path, false /*only_input*/, true /*use cuda*/, + 10); } - test_multiple_threads(model_path, false /*not use cuda*/, 10); + test_multiple_threads(model_path, false /*only_input*/, + false /*not use cuda*/, 10); return 0; } diff --git a/tools/convert_huggingface_bert_pytorch_to_npz.py b/tools/convert_huggingface_bert_pytorch_to_npz.py index 87ae8889..a1ff2c7f 100644 --- a/tools/convert_huggingface_bert_pytorch_to_npz.py +++ b/tools/convert_huggingface_bert_pytorch_to_npz.py @@ -50,7 +50,7 @@ def main(): arrays[k], arrays[k[:-len(q_weight_key)] + k_weight_key], arrays[k[:-len(q_weight_key)] + v_weight_key] - ], 0))) + ], 0).contiguous()).contiguous()) numpy_dict[k[:-len(q_weight_key)] + "qkv.weight"] = v.numpy() elif k.endswith(q_bias_key): v = torch.cat([ @@ -65,7 +65,8 @@ def main(): or k.endswith("pooler.dense.weight") or (k.endswith("output.dense.weight") or k.endswith("intermediate.dense.weight"))): - numpy_dict[k] = torch.clone(torch.t(arrays[k])).numpy() + numpy_dict[k] = torch.clone(torch.t( + arrays[k]).contiguous()).numpy() else: numpy_dict[k] = arrays[k].numpy() del arrays diff --git a/tools/docker/Dockerfile_dev.cpu b/tools/docker/Dockerfile_dev.cpu index fd84c3bd..9c91781e 100644 --- a/tools/docker/Dockerfile_dev.cpu +++ b/tools/docker/Dockerfile_dev.cpu @@ -1,6 +1,6 @@ # Build ONNX Runtime for benchmark FROM continuumio/miniconda3 as OnnxRTBuilder -RUN apt-get update && apt-get install -y build-essential +RUN apt-get update && apt-get install -y build-essential gfortran RUN /opt/conda/bin/conda install mklml cmake curl git numpy -c anaconda RUN mkdir /src && cd /src/ && git clone -v https://github.com/microsoft/onnxruntime.git &&\ cd onnxruntime && git checkout v1.2.0 && git submodule update --init --recursive @@ -22,6 +22,7 @@ COPY --from=PProfBuilder /go/bin/pprof /bin/pprof # NOTE: 1. MKL is installed with pytorch. 
# turbo-transformers will use the same MKL from PyTorch RUN /opt/conda/bin/conda install pytorch==1.4.0 cpuonly -c pytorch && \ + pip install OpenNMT-py==1.1.1 && \ /opt/conda/bin/conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \ /opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \ /opt/conda/bin/conda clean -afy diff --git a/tools/docker/Dockerfile_dev.gpu b/tools/docker/Dockerfile_dev.gpu index cfcb3014..8260f8dd 100644 --- a/tools/docker/Dockerfile_dev.gpu +++ b/tools/docker/Dockerfile_dev.gpu @@ -1,7 +1,7 @@ FROM IMAGE_BASE RUN apt-get update && \ - apt-get install -y curl git wget bzip2 build-essential ninja-build g++ && rm -rf /var/lib/apt/lists/* + apt-get install -y curl git wget bzip2 build-essential gfortran ninja-build g++ && rm -rf /var/lib/apt/lists/* ENV PATH=/opt/miniconda3/bin:${PATH} CONDA_PREFIX=/opt/miniconda3 RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ @@ -11,6 +11,7 @@ RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.s conda install pytorch=PYTORCH_VERSION torchvision==0.3.0 cudatoolkit=CUDA_VERSION -c pytorch && \ conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \ conda install git git-lfs docopt -c conda-forge && \ + pip install OpenNMT-py==1.1.1 && \ conda clean -afy RUN pip --no-cache-dir install contexttimer future transformers docopt diff --git a/tools/docker/Dockerfile_release.cpu b/tools/docker/Dockerfile_release.cpu index 4c314009..fdca94c3 100644 --- a/tools/docker/Dockerfile_release.cpu +++ b/tools/docker/Dockerfile_release.cpu @@ -8,6 +8,7 @@ ENV PATH=/opt/conda/bin/bin:${PATH} CONDA_PREFIX=/opt/conda # NOTE: 1. MKL is installed with pytorch. # turbo-transformers will use the same MKL from PyTorch RUN /opt/conda/bin/conda install pytorch==1.4.0 cpuonly -c pytorch && \ + pip install OpenNMT-py==1.1.1 && \ /opt/conda/bin/conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \ /opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \ /opt/conda/bin/conda clean -afy diff --git a/turbo_transformers/core/CMakeLists.txt b/turbo_transformers/core/CMakeLists.txt index 456cdc0f..df7a6659 100644 --- a/turbo_transformers/core/CMakeLists.txt +++ b/turbo_transformers/core/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(tt_core tensor.cpp config.cpp profiler.cpp + allocator.cpp ) target_link_libraries(tt_core PUBLIC absl::stacktrace @@ -48,17 +49,15 @@ endif () if (WITH_GPU) - target_sources(tt_core PRIVATE cuda_device_context.cpp cuda_allocator.cpp) + target_sources(tt_core PRIVATE cuda_device_context.cpp) target_link_libraries(tt_core PUBLIC cudart cuda cublas) endif() -if (WITH_PROFILER) - target_link_libraries(tt_core gperftools::profiler) -endif () add_executable(tt_core_test enforce_test.cpp device_context_test.cpp tensor_test.cpp + allocator_test.cpp fp16_test.cpp) target_link_libraries(tt_core_test catch2_test_main tt_core) add_test(NAME tt_core_test COMMAND tt_core_test) diff --git a/turbo_transformers/core/allocator.cpp b/turbo_transformers/core/allocator.cpp new file mode 100644 index 00000000..7d6b071c --- /dev/null +++ b/turbo_transformers/core/allocator.cpp @@ -0,0 +1,216 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. 
You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. + +#include "turbo_transformers/core/allocator.h" + +#include + +#ifdef TT_WITH_CUDA +#include + +#include +#include + +#include "turbo_transformers/core/cuda_device_context.h" +#include "turbo_transformers/core/cuda_enforce.cuh" +#endif + +namespace turbo_transformers { +namespace core { + +struct BadAlloc : public std::exception { + explicit BadAlloc(std::string err_msg) : err_str_(err_msg) {} + + const char *what() const noexcept override { return err_str_.c_str(); } + + std::string err_str_; +}; + +#ifdef TT_WITH_CUDA +static void *cuda_alloc(size_t sz) { + void *device_mem; + try { + TT_ENFORCE_CUDA_SUCCESS(cudaMalloc((void **)&(device_mem), sz)); + } catch (...) { + throw BadAlloc("cudaMalloc failed."); + } + return device_mem; +} + +static void cuda_free(void *data) { TT_ENFORCE_CUDA_SUCCESS(cudaFree(data)); } +#endif + +namespace { +void *allocate_impl(size_t size, DLDeviceType dev) { + if (kDLCPU == dev) { + return align_alloc(size); + } else if (kDLGPU == dev) { +#ifdef TT_WITH_CUDA + auto addr = cuda_alloc(size); + return addr; +#endif + } else { + TT_THROW("Not supported devtype"); + } + return nullptr; +} + +void free_impl(void *memory_addr, DLDeviceType dev) { + if (kDLCPU == dev) { + free(memory_addr); + } else if (kDLGPU == dev) { +#ifdef TT_WITH_CUDA + cuda_free(memory_addr); +#endif + } else { + TT_THROW("Not supported devtype"); + } +} + +} // namespace + +struct Allocator::BestFitAllocatorImpl { + public: + void free_cache(size_t size, DLDeviceType dev) { + if (size == 0) return; + size_t cur = 0; + while (!allocations_.empty()) { // free the largest + auto it = --allocations_.end(); + cur += it->first; + free_impl(it->second, dev); + addr_size_map_.erase(it->second); + allocation_size_ -= it->first; + allocations_.erase(it); + if (cur >= size) return; + } + } + + void *alloc(size_t size, DLDeviceType dev) { + auto it = allocations_.lower_bound(size); + void *allocated_addr; + if (it != allocations_.end() && it->first >= size) { + allocated_addr = it->second; + allocations_.erase(it); + } else { + try { + allocated_addr = allocate_impl(size, dev); + } catch (BadAlloc &) { + free_cache(size, dev); + allocated_addr = allocate_impl(size, dev); + } + } + + addr_size_map_[allocated_addr] = size; + return allocated_addr; + } + + void free(void *data, DLDeviceType dev) { + auto size = addr_size_map_[data]; + allocations_.emplace(size, data); + allocation_size_ += size; + addr_size_map_.erase(data); + } + + BestFitAllocatorImpl() : allocation_size_(0) {} + + private: + std::multimap allocations_; + std::unordered_map addr_size_map_; + size_t allocation_size_; +}; // struct Allocator::BestFitAllocatorImpl + +struct Allocator::CachingAllocatorImpl { + void *alloc(size_t size, DLDeviceType dev) { + void *data = nullptr; + if (dev == kDLCPU) { + return allocate_impl(size, kDLCPU); + } else if (dev == kDLGPU) { +#ifdef TT_WITH_CUDA + static auto stream = core::CUDADeviceContext::GetInstance().stream(); + try { + cudaError_t result = cub_allocator.DeviceAllocate(&data, size, stream); + if (result != cudaSuccess) { + throw 
BadAlloc("DeviceAllocate failed."); + } + } catch (...) { + cub_allocator.FreeAllCached(); + cudaError_t result = cub_allocator.DeviceAllocate(&data, size, stream); + if (result != cudaSuccess) { + std::stringstream ss; + ss << "DeviceAllocate failed Again. " << size; + throw BadAlloc(ss.str()); + } + } +#endif + } + return data; + + } // alloc + + void free(void *data, DLDeviceType dev) { + if (dev == kDLCPU) { + free_impl(data, kDLCPU); + } else if (dev == kDLGPU) { +#ifdef TT_WITH_CUDA + try { + cudaError_t result = cub_allocator.DeviceFree(data); + if (result != cudaErrorCudartUnloading && result != cudaSuccess) { + throw std::runtime_error("DeviceFree failed "); + } + } catch (...) { + } +#endif + } + } + + private: +#ifdef TT_WITH_CUDA + cub::CachingDeviceAllocator cub_allocator; +#endif +}; // struct Allocator::CachingAllocatorImpl + +/********* + * APIs of Allocator + *********/ +Allocator::Allocator() + : bestfit_allocator_(new BestFitAllocatorImpl()), + caching_allocator_(new CachingAllocatorImpl()) {} +Allocator::~Allocator() = default; + +void *Allocator::allocate(size_t size, const std::string &strategy, + DLDeviceType dev) { + if (dev == kDLCPU) { + return allocate_impl(size, dev); + } + if ("bestfit" == strategy) { + return bestfit_allocator_->alloc(size, dev); + } else if ("cub" == strategy) { + return caching_allocator_->alloc(size, dev); + } + return nullptr; +} + +void Allocator::free(void *memory, const std::string &strategy, + DLDeviceType dev) { + if (dev == kDLCPU) { + return free_impl(memory, dev); + } + if ("bestfit" == strategy) { + bestfit_allocator_->free(memory, dev); + } else if ("cub" == strategy) { + caching_allocator_->free(memory, dev); + } +} + +} // namespace core +} // namespace turbo_transformers diff --git a/turbo_transformers/core/cuda_allocator.h b/turbo_transformers/core/allocator.h similarity index 62% rename from turbo_transformers/core/cuda_allocator.h rename to turbo_transformers/core/allocator.h index 36df1247..23787e78 100644 --- a/turbo_transformers/core/cuda_allocator.h +++ b/turbo_transformers/core/allocator.h @@ -16,32 +16,34 @@ #include #include - +#include #include "macros.h" +#include "turbo_transformers/core/memory.h" namespace turbo_transformers { namespace core { -class CUDAAllocator { +class Allocator { public: - ~CUDAAllocator(); + ~Allocator(); - static CUDAAllocator &GetInstance() { - static CUDAAllocator instance; + static Allocator &GetInstance() { + static Allocator instance; return instance; } - void *allocate(size_t size); + void *allocate(size_t size, const std::string &strategy, DLDeviceType dev); - void free(void *memory); + void free(void *memory, const std::string &strategy, DLDeviceType dev); private: - CUDAAllocator(); - - struct AllocatorImpl; - std::unique_ptr allocator_; + Allocator(); + struct BestFitAllocatorImpl; + std::unique_ptr bestfit_allocator_; + struct CachingAllocatorImpl; + std::unique_ptr caching_allocator_; - DISABLE_COPY_AND_ASSIGN(CUDAAllocator); + DISABLE_COPY_AND_ASSIGN(Allocator); }; } // namespace core diff --git a/turbo_transformers/core/allocator_test.cpp b/turbo_transformers/core/allocator_test.cpp new file mode 100644 index 00000000..8e754432 --- /dev/null +++ b/turbo_transformers/core/allocator_test.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. 
You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. +#include "turbo_transformers/core/allocator.h" + +#include + +#include "catch2/catch.hpp" +#include "turbo_transformers/core/tensor.h" + +namespace turbo_transformers { +namespace core { + +#ifdef TT_WITH_CUDA +TEST_CASE("cuda_allocator_default", "Test the default allocator for tensor") { + std::vector size_list{100, 100, 1000, 256, 200}; + std::vector addr_list(4); + for (size_t i = 0; i < size_list.size(); ++i) { + turbo_transformers::core::Tensor test_tensor( + turbo_transformers::core::NewDLPackTensorT({size_list[i]}, + kDLGPU)); + } +} + +TEST_CASE("cuda_allocator_cub", "Test the cubcaching allocator") { + Allocator &allocator = Allocator::GetInstance(); + std::vector size_list{100, 100, 1000, 256, 200}; + std::vector addr_list(4); + for (size_t i = 0; i < size_list.size(); ++i) { + addr_list[i] = allocator.allocate(size_list[i], "cub", kDLGPU); + allocator.free(addr_list[i], "cub", kDLGPU); + } +} + +TEST_CASE("cuda_allocator_bestfit", "Test the bestfit allocator") { + Allocator &allocator = Allocator::GetInstance(); + std::vector size_list{100, 100, 1000, 256, 200}; + std::vector addr_list(4); + for (size_t i = 0; i < size_list.size(); ++i) { + addr_list[i] = allocator.allocate(size_list[i], "bestfit", kDLGPU); + allocator.free(addr_list[i], "bestfit", kDLGPU); + } +} +#endif + +} // namespace core +} // namespace turbo_transformers diff --git a/turbo_transformers/core/cuda_allocator.cpp b/turbo_transformers/core/cuda_allocator.cpp deleted file mode 100644 index ff463fa2..00000000 --- a/turbo_transformers/core/cuda_allocator.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (C) 2020 THL A29 Limited, a Tencent company. -// All rights reserved. -// Licensed under the BSD 3-Clause License (the "License"); you may -// not use this file except in compliance with the License. You may -// obtain a copy of the License at -// https://opensource.org/licenses/BSD-3-Clause -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" basis, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -// implied. See the License for the specific language governing -// permissions and limitations under the License. -// See the AUTHORS file for names of contributors. 
-#include "turbo_transformers/core/cuda_allocator.h" -#include - -#include - -#include "turbo_transformers/core/cuda_device_context.h" - -namespace turbo_transformers { -namespace core { - -struct BadAlloc : public std::exception { - explicit BadAlloc(std::string err_msg) : err_str_(err_msg) {} - - const char *what() const noexcept override { return err_str_.c_str(); } - - std::string err_str_; -}; - -struct CUDAAllocator::AllocatorImpl { - void *alloc(size_t size) { - static auto stream = core::CUDADeviceContext::GetInstance().stream(); - void *data = nullptr; - cudaError_t result = cub_allocator.DeviceAllocate(&data, size, stream); - if (result != cudaSuccess) { - throw BadAlloc("DeviceAllocate failed."); - } - return data; - } - - void free(void *data) { - try { - cudaError_t result = cub_allocator.DeviceFree(data); - if (result != cudaErrorCudartUnloading && result != cudaSuccess) { - throw std::runtime_error("DeviceFree failed "); - } - } catch (...) { - } - } - - void free_all_cache() { cub_allocator.FreeAllCached(); } - - ~AllocatorImpl() { cub_allocator.FreeAllCached(); } - - cub::CachingDeviceAllocator cub_allocator; -}; - -CUDAAllocator::CUDAAllocator() : allocator_(new AllocatorImpl()) {} - -CUDAAllocator::~CUDAAllocator() = default; - -void *CUDAAllocator::allocate(size_t size) { - try { - return allocator_->alloc(size); - } catch (BadAlloc &) { - allocator_->free_all_cache(); - return allocator_->alloc(size); - } -} - -void CUDAAllocator::free(void *memory) { allocator_->free(memory); } - -} // namespace core -} // namespace turbo_transformers diff --git a/turbo_transformers/core/profiler.cpp b/turbo_transformers/core/profiler.cpp index 4d0f5c73..d0e5e87f 100644 --- a/turbo_transformers/core/profiler.cpp +++ b/turbo_transformers/core/profiler.cpp @@ -16,32 +16,154 @@ #include "enforce.h" #include "loguru.hpp" -#ifdef WITH_GPERFTOOLS -#include "gperftools/profiler.h" +#ifdef WITH_PERFTOOLS +#include +#include +#include +#include +#ifdef TT_WITH_CUDA +#include "turbo_transformers/core/cuda_device_context.h" +#endif #endif namespace turbo_transformers { namespace core { -#ifdef WITH_GPERFTOOLS -static bool gProfileStarted = false; +#ifdef WITH_PERFTOOLS +static bool gProfileEnabled = false; + +static bool comp(std::pair a, + std::pair b) { + return a.second < b.second; +} + +struct Profiler::ProfilerImpl { + void start_profile(const std::string& ctx_name, DLDeviceType dev_type) { + if (kDLGPU == dev_type) { +#ifdef TT_WITH_CUDA + cudaEvent_t start_event; + static auto stream = core::CUDADeviceContext::GetInstance().stream(); + cudaEventCreate(&start_event); + cudaEventRecord(start_event, stream); + event_stack_.push(start_event); +#endif + } else if (kDLCPU == dev_type) { + auto start = std::chrono::system_clock::now(); + clock_stack_.push(start); + } + } + void end_profile(const std::string& ctx_name, DLDeviceType dev_type) { + float elapsed_time; + if (kDLGPU == dev_type) { +#ifdef TT_WITH_CUDA + cudaEvent_t stop_event; + cudaEventCreate(&stop_event); + static auto stream = core::CUDADeviceContext::GetInstance().stream(); + cudaEventRecord(stop_event, stream); + cudaEventSynchronize(stop_event); + auto start_event = event_stack_.top(); + event_stack_.pop(); + cudaEventElapsedTime(&elapsed_time, start_event, stop_event); #endif + } else if (kDLCPU == dev_type) { + auto end = std::chrono::system_clock::now(); + if (clock_stack_.empty()) + TT_THROW("Profiler %s has no start time", ctx_name.c_str()); + auto start = clock_stack_.top(); + clock_stack_.pop(); + auto duration = + 
std::chrono::duration_cast(end - start); + elapsed_time = float(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den; + } + + if (timer_map_.find(ctx_name) != timer_map_.end()) { + timer_map_[ctx_name] += elapsed_time; + } else { + timer_map_.insert({ctx_name, elapsed_time}); + } + } + void print_results() const { + std::cerr << std::endl << profile_name_ << " Time line: " << std::endl; + std::vector> elems(timer_map_.begin(), + timer_map_.end()); + std::sort(elems.begin(), elems.end(), comp); + float total_elapsed = 0.; + for (auto it = timer_map_.begin(); it != timer_map_.end(); ++it) { + total_elapsed += it->second; + } + for (auto it = elems.begin(); it != elems.end(); ++it) { + std::cerr << it->first << " , " << it->second << ", " + << it->second / total_elapsed * 100 << " % " << std::endl; + } + } + void clear() { + timer_map_.clear(); + while (!clock_stack_.empty()) { + clock_stack_.pop(); + } +#ifdef TT_WITH_CUDA + while (!event_stack_.empty()) { + event_stack_.pop(); + } +#endif + } + void set_name(const std::string& profile_name) { + profile_name_ = profile_name; + } + + private: + std::unordered_map timer_map_; + std::stack> clock_stack_; +#ifdef TT_WITH_CUDA + std::stack event_stack_; +#endif + std::string profile_name_; +}; + +void Profiler::start_profile(const std::string& ctx_name, + DLDeviceType dev_type) { + if (gProfileEnabled) profiler_->start_profile(ctx_name, dev_type); +} + +void Profiler::end_profile(const std::string& ctx_name, DLDeviceType dev_type) { + if (gProfileEnabled) profiler_->end_profile(ctx_name, dev_type); +} + +void Profiler::print_results() const { + if (gProfileEnabled) { + profiler_->print_results(); + } +} -void EnableGperf(const std::string &profile_file) { -#ifdef WITH_GPERFTOOLS - LOG_S(1) << "gperf tools enabled." << profile_file; - TT_ENFORCE_EQ(gProfileStarted, false, "Currently the gPerf is enabled."); - ProfilerStart(profile_file.c_str()); - gProfileStarted = true; +void Profiler::clear() { profiler_->clear(); } + +void Profiler::enable(const std::string& profile_name) { + gProfileEnabled = true; + profiler_->set_name(profile_name); +} +void Profiler::disable() { gProfileEnabled = false; } + +Profiler::~Profiler() = default; +Profiler::Profiler() : profiler_(new ProfilerImpl()) {} + +#endif +void EnableGperf(const std::string& profile_name) { +#ifdef WITH_PERFTOOLS + LOG_S(1) << "gperf tools enabled. 
" << profile_name; + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.clear(); + profile_ctx.enable(profile_name); #else LOG_S(WARNING) << "turbo_transformers is not compiled with gperftools."; #endif } void DisableGperf() { -#ifdef WITH_GPERFTOOLS - TT_ENFORCE_EQ(gProfileStarted, true, "Currently the gPerf is disabled."); - ProfilerStop(); - gProfileStarted = false; +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.print_results(); + profile_ctx.disable(); #else LOG_S(WARNING) << "turbo_transformers is not compiled with gperftools."; #endif diff --git a/turbo_transformers/core/profiler.h b/turbo_transformers/core/profiler.h index 80f0a6b4..92909961 100644 --- a/turbo_transformers/core/profiler.h +++ b/turbo_transformers/core/profiler.h @@ -15,9 +15,42 @@ #include +#ifdef WITH_PERFTOOLS +#include + +#include + +#include "macros.h" +#endif + namespace turbo_transformers { namespace core { +#ifdef WITH_PERFTOOLS +class Profiler { + public: + ~Profiler(); + static Profiler& GetInstance() { + static Profiler instance; + return instance; + } + void clear(); + void start_profile(const std::string& ctx_name, + DLDeviceType dev_type = kDLCPU); + void end_profile(const std::string& ctx_name, DLDeviceType dev_type = kDLCPU); + void print_results() const; + void enable(const std::string& profile_name); + void disable(); + + private: + Profiler(); + + struct ProfilerImpl; + std::unique_ptr profiler_; + + DISABLE_COPY_AND_ASSIGN(Profiler); +}; +#endif void EnableGperf(const std::string& profile_file); void DisableGperf(); diff --git a/turbo_transformers/core/tensor.cpp b/turbo_transformers/core/tensor.cpp index 6f2c389a..416713ea 100644 --- a/turbo_transformers/core/tensor.cpp +++ b/turbo_transformers/core/tensor.cpp @@ -14,9 +14,9 @@ #include "tensor.h" #ifdef TT_WITH_CUDA -#include "turbo_transformers/core/cuda_allocator.h" #include "turbo_transformers/core/cuda_device_context.h" #endif +#include "turbo_transformers/core/allocator.h" namespace turbo_transformers { namespace core { @@ -25,13 +25,12 @@ static void DLManagedTensorDeletor(DLManagedTensor *self) { return; } if (self->dl_tensor.data != nullptr) { - if (self->dl_tensor.ctx.device_type == kDLCPU) { - free(self->dl_tensor.data); - } else if (self->dl_tensor.ctx.device_type == kDLGPU) { -#ifdef TT_WITH_CUDA - CUDAAllocator &cuda_allocator = CUDAAllocator::GetInstance(); - cuda_allocator.free(self->dl_tensor.data); -#endif + if (self->dl_tensor.ctx.device_type == kDLCPU || + self->dl_tensor.ctx.device_type == kDLGPU) { + // free(self->dl_tensor.data); + Allocator &allocator = Allocator::GetInstance(); + allocator.free(self->dl_tensor.data, "bestfit", + self->dl_tensor.ctx.device_type); } } @@ -61,14 +60,10 @@ DLManagedTensor *NewDLPackTensor(const std::vector &shape_list, size_t numel = std::accumulate(shape_list.begin(), shape_list.end(), 1, std::multiplies()); - if (device == kDLCPU) { - newTensor->dl_tensor.data = align_alloc(numel * (bits / 8)); - } else if (device == kDLGPU) { -#ifdef TT_WITH_CUDA - CUDAAllocator &cuda_allocator = CUDAAllocator::GetInstance(); + if (device == kDLCPU || device == kDLGPU) { size_t size = numel * (bits / 8); - newTensor->dl_tensor.data = cuda_allocator.allocate(size); -#endif + Allocator &allocator = Allocator::GetInstance(); + newTensor->dl_tensor.data = allocator.allocate(size, "bestfit", device); } else { TT_THROW("only cpu and gpu are supported!"); } diff --git a/turbo_transformers/core/tensor.h b/turbo_transformers/core/tensor.h index 
24483f5c..7192c0fd 100644 --- a/turbo_transformers/core/tensor.h +++ b/turbo_transformers/core/tensor.h @@ -24,7 +24,9 @@ #include "turbo_transformers/core/enforce.h" #include "turbo_transformers/core/half.h" #include "turbo_transformers/core/memory.h" - +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif namespace turbo_transformers { namespace core { @@ -150,13 +152,20 @@ class Tensor { // FIXME(florianzhao): Maybe this func should not be named Reshape. template - T *Reshape(std::initializer_list shape_list, - DLDeviceType device_type, int device_id) { + T *Reshape(std::vector shape_list, DLDeviceType device_type, + int device_id, const std::string name = "Reshape") { // if Need Realloc +#ifdef WITH_PERFTOOLS + auto &profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, device_type); +#endif if (absl::visit(ReshapeNeedRealloc(shape_list), tensor_)) { tensor_ = details::DLManagedTensorPtr( NewDLPackTensorT(shape_list, device_type, device_id)); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, device_type); +#endif return this->template mutableData(); } @@ -207,23 +216,25 @@ class Tensor { os << "shape: "; PrintArray(os, dl_tensor.shape, dl_tensor.ndim); os << "\n"; - os << "first 10 elems: ("; + os << "first and last 10 elems: ("; int cnt = 10; double sum = 0.; if (device_type() == kDLCPU) { + os << "CPU\n"; for (int i = 0; i < numel(); ++i) { sum += data()[i]; - if (cnt-- >= 0) os << data()[i] << ", "; + if (cnt-- >= 0 || numel() - i <= 10) os << data()[i] << ", "; } } else if (device_type() == kDLGPU) { #ifdef TT_WITH_CUDA + os << "GPU\n"; auto n = numel(); std::unique_ptr cpu_data(new T[n]); Memcpy(cpu_data.get(), data(), n * sizeof(T), MemcpyFlag::kGPU2CPU); for (int i = 0; i < n; ++i) { sum += cpu_data[i]; - if (cnt-- >= 0) os << cpu_data[i] << ", "; + if (cnt-- >= 0 || n - i <= 10) os << cpu_data[i] << ", "; } #else TT_THROW("No CUDA supported, Please Compile with TT_WITH_CUDA"); @@ -294,7 +305,7 @@ class Tensor { private: struct ReshapeNeedRealloc { public: - ReshapeNeedRealloc(const std::initializer_list &shape_list) + ReshapeNeedRealloc(const std::vector &shape_list) : shape_list_(shape_list) {} bool operator()(details::DLManagedTensorPtr &ptr) const { @@ -321,7 +332,7 @@ class Tensor { } private: - const std::initializer_list &shape_list_; + const std::vector &shape_list_; }; const DLTensor &to_dl_tensor() const { @@ -331,22 +342,5 @@ class Tensor { details::TensorPayload tensor_; }; -struct TempTensor { - TempTensor() : cpu_tensor(nullptr), gpu_tensor(nullptr) {} - - core::Tensor &GetTensor(DLContext context) { - if (context.device_type == kDLCPU) { - return cpu_tensor; - } else if (context.device_type == kDLGPU) { - return gpu_tensor; - } else { - TT_THROW("This device is not support."); - } - } - - private: - core::Tensor cpu_tensor; - core::Tensor gpu_tensor; -}; } // namespace core } // namespace turbo_transformers diff --git a/turbo_transformers/core/tensor_copy.h b/turbo_transformers/core/tensor_copy.h index 05bc3b2f..b9c32f12 100644 --- a/turbo_transformers/core/tensor_copy.h +++ b/turbo_transformers/core/tensor_copy.h @@ -14,9 +14,14 @@ #pragma once #include + #include "turbo_transformers/core/memory.h" #include "turbo_transformers/core/tensor.h" +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif + namespace turbo_transformers { namespace core { template @@ -38,12 +43,20 @@ static inline void Copy(const core::Tensor &src, std::vector &dst) { core::Memcpy(dst.data(), src.data(), 
sizeof(T) * src.numel(), flag); } template -static inline void Copy(const core::Tensor &src, core::Tensor &dst) { +static inline void Copy(const core::Tensor &src, core::Tensor &dst, + const std::string name = "Copy") { +#ifdef WITH_PERFTOOLS + auto &profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, src.device_type()); +#endif TT_ENFORCE_EQ(dst.numel(), src.numel(), "Copy two tensors should have the same size"); auto flag = core::ToMemcpyFlag(dst.device_type(), src.device_type()); core::Memcpy(dst.mutableData(), src.data(), sizeof(T) * src.numel(), flag); +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, src.device_type()); +#endif } } // namespace core diff --git a/turbo_transformers/layers/CMakeLists.txt b/turbo_transformers/layers/CMakeLists.txt index 29cfa696..7e1de9e3 100644 --- a/turbo_transformers/layers/CMakeLists.txt +++ b/turbo_transformers/layers/CMakeLists.txt @@ -21,6 +21,8 @@ add_library(tt_layers OBJECT sequence_pool.cpp bert_pooler.cpp prepare_bert_masks.cpp + multi_headed_attention.cpp + positionwise_ffn.cpp ) target_link_libraries(tt_layers PUBLIC tt_core tt_kernels) diff --git a/turbo_transformers/layers/README.md b/turbo_transformers/layers/README.md new file mode 100644 index 00000000..1756f83b --- /dev/null +++ b/turbo_transformers/layers/README.md @@ -0,0 +1,13 @@ +## How to add an optimized custom layer + +Take [multi_headed_attention](https://github.com/OpenNMT/OpenNMT-py/blob/b98fb3d7cb/onmt/modules/multi_headed_attn.py) as an example. + +1. implement your layer as in ./layers/multiple_headed_attention.cpp + +2. add python API in turbo_transformers/python/turbo_transformers/layers/modeling_decoder.py + +3. register in ./turbo_transformers/python/turbo_transformers/layers/__init__.py + +4. add a `py::class_` in ./turbo_transformers/python/pybind.cpp + +5. 
add an unitest in ./turbo_transformers/python/tests/multi_headed_attention_test.py diff --git a/turbo_transformers/layers/bert_attention.cpp b/turbo_transformers/layers/bert_attention.cpp index 7e49eda5..17342625 100644 --- a/turbo_transformers/layers/bert_attention.cpp +++ b/turbo_transformers/layers/bert_attention.cpp @@ -13,6 +13,8 @@ #include "turbo_transformers/layers/bert_attention.h" +#include + #include "loguru.hpp" #include "turbo_transformers/core/memory.h" #include "turbo_transformers/layers/kernels/common.h" @@ -20,6 +22,7 @@ #include "turbo_transformers/layers/kernels/mat_mul.h" #include "turbo_transformers/layers/kernels/softmax.h" #include "turbo_transformers/layers/kernels/transpose.h" + namespace turbo_transformers { namespace layers { @@ -27,107 +30,26 @@ static std::mutex mutex_; void BertAttention::operator()(const core::Tensor& input_tensor, const core::Tensor& attention_mask, - core::Tensor* output) const { - std::lock_guard g(mutex_); - TT_ENFORCE_EQ(kernels::common::is_same_device_ctx( - input_tensor.device_ctx(), attention_mask.device_ctx()), - true, - "The input_tensor and attention_mask should have the same " - "device type and device id."); - - TT_ENFORCE_EQ(input_tensor.n_dim(), 3, - "The input ids should be a matrix with shape [BatchSize, " - "SeqLen, HiddenSize]."); - EnforceShapeAndType(); - auto batch_size = input_tensor.shape(0); - auto seq_length = input_tensor.shape(1); - auto hidden_size = input_tensor.shape(2); - auto size_per_head = hidden_size / num_attention_heads_; - LOG_S(3) << "batch_size: " << batch_size - << ", num_head: " << num_attention_heads_ - << ", seq_length: " << seq_length << ", hidden_size: " << hidden_size - << ", size_per_head: " << size_per_head; - output->Reshape({batch_size, seq_length, hidden_size}, - input_tensor.device_type(), input_tensor.device_id()); - - // 1. temp_qkv = MatMul(input) - static core::TempTensor temp_qkv_tmp; - core::Tensor& temp_qkv = temp_qkv_tmp.GetTensor(input_tensor.device_ctx()); - temp_qkv.Reshape({3, batch_size, seq_length, hidden_size}, - input_tensor.device_type(), input_tensor.device_id()); - - kernels::MatMul(input_tensor, false, qkv_weight_, false, 1.0, &temp_qkv, 0.0); - - // 2. qkv = transpose(temp_qkv + bias) - // Since `SplitAddBiasTransposeForScore` does not support inplace, - // qkv and temp_qkv cannot be same tensor - static core::TempTensor qkv_tensor_tmp; - core::Tensor& qkv = qkv_tensor_tmp.GetTensor(input_tensor.device_ctx()); - qkv.Reshape( - {3, batch_size, num_attention_heads_, seq_length, size_per_head}, - input_tensor.device_type(), input_tensor.device_id()); - - kernels::SplitAddBiasTransposeForScore(&qkv, temp_qkv, qkv_bias_); - // 3. q = qkv[0]; k = qkv[1]; v = qkv[2]; - auto q = qkv[0]; - auto k = qkv[1]; - auto v = qkv[2]; - - // 4. att_score = softmax((q * k^T)*1/sqrt(size_per_head) + att_mask) - static core::TempTensor att_score_tmp; - core::Tensor& att_score = att_score_tmp.GetTensor(input_tensor.device_ctx()); - att_score.Reshape( - {batch_size, num_attention_heads_, seq_length, seq_length}, - input_tensor.device_type(), input_tensor.device_id()); - kernels::BatchMatMul(q, false, k, true, 1.0, &att_score, 0.0); - - kernels::ApplyMaskAndSoftmax( - &att_score, attention_mask, - 1 / std::sqrt(static_cast(size_per_head))); - // 5. 
ctx = v * att_score - static core::TempTensor context_layer_tmpr; - core::Tensor& context_layer = - context_layer_tmpr.GetTensor(input_tensor.device_ctx()); - context_layer.Reshape( - {batch_size, num_attention_heads_, seq_length, size_per_head}, - input_tensor.device_type(), input_tensor.device_id()); - kernels::BatchMatMul(att_score, false, v, false, 1.0, &context_layer, 0.0); - - // 6. self_att_out = transpose(ctx) - static core::TempTensor self_attr_out_tmp; - core::Tensor& self_attr_out = - self_attr_out_tmp.GetTensor(input_tensor.device_ctx()); - self_attr_out.Reshape( - {batch_size, seq_length, num_attention_heads_ * size_per_head}, - input_tensor.device_type(), input_tensor.device_id()); - - kernels::TransposeForScore(&self_attr_out, context_layer); - - // 7. output = LayerNorm(MatMul(self_att_out) + Bias) - kernels::MatMul(self_attr_out, false, dense_weight_, false, 1.0, output, 0.0); - - kernels::AddBiasLayerNorm(input_tensor, dense_bias_, - layer_norm_weight_, // gemma - layer_norm_bias_, output); + core::Tensor* output, core::Tensor* attn, + bool is_trans_weight) const { + std::unordered_map dummy{}; + core::Tensor* attn_ptr; + if (attn == nullptr) { + attn_ptr = new core::Tensor(nullptr); + } else { + attn_ptr = attn; + } + MultiHeadedAttention::operator()( + input_tensor, input_tensor, input_tensor, attention_mask, "self", output, + attn_ptr, dummy, false /* pre_layernorm */, true /* post_layernorm */, + false /* post_add_input */, is_trans_weight /* is_trans_weight */); + if (attn == nullptr) { + delete attn_ptr; + } } void BertAttention::EnforceShapeAndType() const { - if (loguru::current_verbosity_cutoff() >= 3) { - std::ostringstream os; - os << ">>>>>>>>>>>> qkv_weight_ <<<<<<<<<<<<" << std::endl; - qkv_weight_.Print(os); - os << ">>>>>>>>>>>> qkv_bias_ <<<<<<<<<<<<" << std::endl; - qkv_bias_.Print(os); - os << ">>>>>>>>>>>> dense_weight_ <<<<<<<<<<<<" << std::endl; - dense_weight_.Print(os); - os << ">>>>>>>>>>>> dense_bias_ <<<<<<<<<<<<" << std::endl; - dense_bias_.Print(os); - os << ">>>>>>>>>>>> layer_norm_weights <<<<<<<<<<<<" << std::endl; - layer_norm_weight_.Print(os); - os << ">>>>>>>>>>>> layer_norm_bias <<<<<<<<<<<<" << std::endl; - layer_norm_bias_.Print(os); - LOG_S(3) << os.str(); - } + MultiHeadedAttention::EnforceShapeAndType(); } } // namespace layers diff --git a/turbo_transformers/layers/bert_attention.h b/turbo_transformers/layers/bert_attention.h index 2bfc37b0..ebcd8af9 100644 --- a/turbo_transformers/layers/bert_attention.h +++ b/turbo_transformers/layers/bert_attention.h @@ -14,42 +14,36 @@ #pragma once #include #include - #include + #include "turbo_transformers/core/tensor.h" +#include "turbo_transformers/layers/multi_headed_attention.h" namespace turbo_transformers { namespace layers { -class BertAttention { +class BertAttention : public MultiHeadedAttention { public: BertAttention(core::Tensor qkv_weight, core::Tensor qkv_bias, core::Tensor dense_weight, core::Tensor dense_bias, core::Tensor layer_norm_weight, core::Tensor layer_norm_bias, int64_t num_attention_heads) - : qkv_weight_(std::move(qkv_weight)), //(768, 768) - qkv_bias_(std::move(qkv_bias)), - dense_weight_(std::move(dense_weight)), - dense_bias_(std::move(dense_bias)), - layer_norm_weight_(std::move(layer_norm_weight)), //(768) - layer_norm_bias_(std::move(layer_norm_bias)), - num_attention_heads_(num_attention_heads) { + : MultiHeadedAttention( + std::move(core::Tensor(nullptr)), std::move(core::Tensor(nullptr)), + std::move(core::Tensor(nullptr)), 
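// The core::Tensor(nullptr) arguments here are placeholders: they presumably
// occupy the separate per-projection (q/k/v) weight and bias slots that
// MultiHeadedAttention keeps for the decoder paths. BERT self-attention does
// not need them because it uses the fused qkv_weight_/qkv_bias_ passed below,
// together with the dense and layer-norm parameters.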
std::move(core::Tensor(nullptr)), + std::move(core::Tensor(nullptr)), std::move(core::Tensor(nullptr)), + std::move(dense_weight), std::move(dense_bias), + std::move(qkv_weight), std::move(qkv_bias), + std::move(layer_norm_weight), //(768) + std::move(layer_norm_bias), num_attention_heads) { EnforceShapeAndType(); } void EnforceShapeAndType() const; void operator()(const core::Tensor &input_tensor, - const core::Tensor &attention_mask, - core::Tensor *output) const; - - private: - core::Tensor qkv_weight_; - core::Tensor qkv_bias_; - core::Tensor dense_weight_; - core::Tensor dense_bias_; - core::Tensor layer_norm_weight_; - core::Tensor layer_norm_bias_; - int64_t num_attention_heads_; + const core::Tensor &attention_mask, core::Tensor *output, + core::Tensor *attn = nullptr, + bool is_trans_weight = false) const; }; } // namespace layers diff --git a/turbo_transformers/layers/kernels/CMakeLists.txt b/turbo_transformers/layers/kernels/CMakeLists.txt index 4eabf926..81c36c3f 100644 --- a/turbo_transformers/layers/kernels/CMakeLists.txt +++ b/turbo_transformers/layers/kernels/CMakeLists.txt @@ -13,7 +13,7 @@ add_library(tt_kernels OBJECT layer_norm.cpp softmax.cpp transpose.cpp activation.cpp - common.cpp seq_pool.cpp mat_mul.cpp) + common.cpp seq_pool.cpp mat_mul.cpp utils.cpp) target_link_libraries(tt_kernels PUBLIC tt_core) if (WITH_GPU) @@ -33,7 +33,9 @@ add_executable(tt_kernels_test softmax_test.cpp transpose_test.cpp layer_norm_test.cpp - mat_mul_test.cpp) + mat_mul_test.cpp + utils_test.cpp + gpu_utils_test.cpp) target_link_libraries(tt_kernels_test tt_kernels tt_core catch2_test_main) add_test(NAME tt_kernels_test COMMAND tt_kernels_test) diff --git a/turbo_transformers/layers/kernels/activation.cpp b/turbo_transformers/layers/kernels/activation.cpp index 4630debc..a8948fa6 100644 --- a/turbo_transformers/layers/kernels/activation.cpp +++ b/turbo_transformers/layers/kernels/activation.cpp @@ -16,6 +16,9 @@ #include "turbo_transformers/core/cuda_device_context.h" #include "turbo_transformers/layers/kernels/gpu_activation_kernel.h" #endif +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif namespace turbo_transformers { namespace layers { @@ -66,10 +69,31 @@ void CPUAddBiasActKernel(const float *bias, vsTanh(feature_dim, &out[i * feature_dim], &out[i * feature_dim]); } } + +template <> +void CPUAddBiasActKernel(const float *bias, + int64_t batch_size, + int64_t feature_dim, + float *out) { +#pragma omp parallel for + for (int64_t i = 0; i < batch_size; ++i) { + int64_t k = 0; +#pragma omp simd + for (int64_t j = feature_dim * i; j < feature_dim * (i + 1); ++j) { + out[j] = out[j] + bias[k++]; + out[j] = out[j] > 0. ? 
out[j] : 0.; + } + } +} } // namespace template -void AddBiasAct(const core::Tensor &bias_tensor, core::Tensor *out_tensor) { +void AddBiasAct(const core::Tensor &bias_tensor, core::Tensor *out_tensor, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto &profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, bias_tensor.device_type()); +#endif auto *out = out_tensor->mutableData(); auto *bias = bias_tensor.data(); @@ -89,13 +113,23 @@ void AddBiasAct(const core::Tensor &bias_tensor, core::Tensor *out_tensor) { TT_THROW("device_type %d is not supported for AddBiasAct", out_tensor->device_type()); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, bias_tensor.device_type()); +#endif } template void AddBiasAct( - const core::Tensor &bias_tensor, core::Tensor *out_tensor); + const core::Tensor &bias_tensor, core::Tensor *out_tensor, + const std::string name); template void AddBiasAct( - const core::Tensor &bias_tensor, core::Tensor *out_tensor); + const core::Tensor &bias_tensor, core::Tensor *out_tensor, + const std::string name); + +template void AddBiasAct( + const core::Tensor &bias_tensor, core::Tensor *out_tensor, + const std::string name); + } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/activation.h b/turbo_transformers/layers/kernels/activation.h index 5ebbcdbb..a4121164 100644 --- a/turbo_transformers/layers/kernels/activation.h +++ b/turbo_transformers/layers/kernels/activation.h @@ -21,7 +21,8 @@ namespace kernels { using types::ActivationType; template -void AddBiasAct(const core::Tensor& bias, core::Tensor* out); +void AddBiasAct(const core::Tensor& bias, core::Tensor* out, + const std::string name = "AddBiasAct"); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/gpu_activation_kernel.cu b/turbo_transformers/layers/kernels/gpu_activation_kernel.cu index c3364df6..5fffcbe8 100644 --- a/turbo_transformers/layers/kernels/gpu_activation_kernel.cu +++ b/turbo_transformers/layers/kernels/gpu_activation_kernel.cu @@ -38,6 +38,13 @@ __inline__ __device__ float ActvationOp( const float& x) { return tanhf(x); } + +template <> +__inline__ __device__ float ActvationOp( + const float& x) { + return (x > 0) ? 
x : 0; +} + } // namespace template @@ -78,6 +85,11 @@ template void GPUAddBiasActKernel( template void GPUAddBiasActKernel( const float* bias_data, int64_t batch_size, int64_t feature_dim, cudaStream_t stream, float* out_data); + +template void GPUAddBiasActKernel( + const float* bias_data, int64_t batch_size, int64_t feature_dim, + cudaStream_t stream, float* out_data); + } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/gpu_softmax_kernel.cu b/turbo_transformers/layers/kernels/gpu_softmax_kernel.cu index f516941c..6aa8fb08 100644 --- a/turbo_transformers/layers/kernels/gpu_softmax_kernel.cu +++ b/turbo_transformers/layers/kernels/gpu_softmax_kernel.cu @@ -62,36 +62,36 @@ struct ArrayMaxFunc { template __global__ void cub_softmax_kernel_k(float* qk_buf_, const float* attr_mask, const int batch_size, const int head_num, - const int seq_len, const float scaler) { - using CubBlockReduce = cub::BlockReduce, BlockDim>; - __shared__ typename CubBlockReduce::TempStorage temp_storage; + const int from_seq_len, + const int to_seq_len, const float scaler, + bool is_2D) { + __shared__ typename cub::BlockReduce, BlockDim>::TempStorage + temp_storage; __shared__ float s_sum[K], s_max[K]; float tmp[K]; - int qk_offset = blockIdx.x * K * seq_len; - - int batch_id = (blockIdx.x * K) / (head_num * seq_len); - int mask_offset = batch_id * seq_len; - float mask_val = - threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : 0.0f; + int qk_offset = blockIdx.x * K * to_seq_len; + float mask_val = 0.; for (int i = 0; i < K; ++i) { - float qk = threadIdx.x < seq_len - ? qk_buf_[threadIdx.x + qk_offset + seq_len * i] + float qk = threadIdx.x < to_seq_len + ? qk_buf_[threadIdx.x + qk_offset + to_seq_len * i] : 0.0f; - int next_batch_id = - i == 0 ? batch_id : (blockIdx.x * K + i) / (head_num * seq_len); - if (batch_id != next_batch_id) { - batch_id = next_batch_id; - mask_val = threadIdx.x < seq_len - ? attr_mask[threadIdx.x + batch_id * seq_len] - : 0.0f; + if (attr_mask != nullptr) { + int batch_id = (blockIdx.x * K + i) / (head_num * from_seq_len); + int from_seq_id = (blockIdx.x * K + i) % from_seq_len; + mask_val = attr_mask[threadIdx.x + + (is_2D ? (batch_id * to_seq_len) + : (batch_id * from_seq_len + from_seq_id) * + to_seq_len)]; + } else { + mask_val = 0.0f; } // mask_val = (1.0f - mask_val) * -10000.0f; - tmp[i] = threadIdx.x < seq_len ? (qk * scaler + mask_val) : -1e20f; + tmp[i] = threadIdx.x < to_seq_len ? (qk * scaler + mask_val) : -1e20f; } Array max_val = - CubBlockReduce(temp_storage) + cub::BlockReduce, BlockDim>(temp_storage) .Reduce(Array(tmp), ArrayMaxFunc()); if (threadIdx.x == 0) { @@ -103,11 +103,11 @@ __global__ void cub_softmax_kernel_k(float* qk_buf_, const float* attr_mask, float qk_tmp[K]; for (int i = 0; i < K; ++i) { - qk_tmp[i] = threadIdx.x < seq_len ? __expf((tmp[i] - s_max[i])) : 0.0f; + qk_tmp[i] = threadIdx.x < to_seq_len ? 
__expf((tmp[i] - s_max[i])) : 0.0f; } Array sum_val = - CubBlockReduce(temp_storage) + cub::BlockReduce, BlockDim>(temp_storage) .Reduce(Array(qk_tmp), ArrayAddFunc()); if (threadIdx.x == 0) { @@ -117,9 +117,10 @@ __global__ void cub_softmax_kernel_k(float* qk_buf_, const float* attr_mask, } __syncthreads(); - if (threadIdx.x < seq_len) { + if (threadIdx.x < to_seq_len) { for (int i = 0; i < K; ++i) { - qk_buf_[threadIdx.x + qk_offset + seq_len * i] = (qk_tmp[i] / s_sum[i]); + qk_buf_[threadIdx.x + qk_offset + to_seq_len * i] = + (qk_tmp[i] / s_sum[i]); } } } @@ -178,22 +179,23 @@ __global__ void cub_softmax_kernel_k(float* qk_buf_, const float* attr_mask, template <> void GPUSoftmaxMask(float* qk_buf, const float* attr_mask, int64_t batch_size, - int64_t head_num, int64_t seq_len, float scale, - cudaStream_t stream) { + int64_t head_num, int64_t from_seq_len, int64_t to_seq_len, + float scale, bool is_2D, cudaStream_t stream) { dim3 block, grid; - int high_dim_size = batch_size * head_num * seq_len; + int high_dim_size = batch_size * head_num * from_seq_len; const int OneRowPerThreadBlock = 1; const int RowsPerThreadBlock = 2; int row_per_thread_block = OneRowPerThreadBlock; - if ((head_num * seq_len) % RowsPerThreadBlock == 0) { + if ((head_num * from_seq_len) % RowsPerThreadBlock == 0) { row_per_thread_block = RowsPerThreadBlock; } // block size must be 32x, so warp reduce can work - block.x = (seq_len + 31) / 32 * 32; + block.x = (to_seq_len + 31) / 32 * 32; grid.x = high_dim_size / row_per_thread_block; // Because there are many function templates, the compilation speed may be // slow. - RUN_KERNEL(qk_buf, attr_mask, batch_size, head_num, seq_len, scale); + RUN_KERNEL(qk_buf, attr_mask, batch_size, head_num, from_seq_len, to_seq_len, + scale, is_2D); } #undef RUN_KERNEL #undef SOFTMAX_KERNEL_CASE diff --git a/turbo_transformers/layers/kernels/gpu_softmax_kernel.h b/turbo_transformers/layers/kernels/gpu_softmax_kernel.h index aae1d549..8aebdd1b 100644 --- a/turbo_transformers/layers/kernels/gpu_softmax_kernel.h +++ b/turbo_transformers/layers/kernels/gpu_softmax_kernel.h @@ -19,8 +19,8 @@ namespace kernels { template void GPUSoftmaxMask(T* qk_buf, const T* attr_mask, int64_t batch_size, - int64_t head_num, int64_t seq_len, float scale, - cudaStream_t stream); + int64_t head_num, int64_t from_seq_len, int64_t to_seq_len, + float scale, bool is_2D, cudaStream_t stream); } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu b/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu index 9ad69249..5d6b65ff 100644 --- a/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu +++ b/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu @@ -73,40 +73,70 @@ void GPUSplitAddBiasTransposeForScore( weight_num, size_per_head, out_data); } -static __global__ void transpose(const float* src, float* dst, - const int batch_size, const int seq_len, - const int head_num, const int size_per_head) { - int tid = threadIdx.x; - int batch_id = blockIdx.x / (head_num * seq_len); - int seq_id = blockIdx.x % seq_len; - int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; +namespace { +// batch, head, seq, size_per_head -> batch head seq size_per_head +template +__global__ void transpose(const float* src, const float* bias, + const int batch_size, const int seq_len, + const int head_num, const int size_per_head, + float* dst) { + int tid = threadIdx.x; int idx = tid; - while (idx < size_per_head) { - 
dst[batch_id * (head_num * seq_len * size_per_head) + - seq_id * head_num * size_per_head + head_id * size_per_head + idx] = - src[blockIdx.x * size_per_head + idx]; - idx += blockDim.x; + if (AddBias) { + int batch_id = blockIdx.x / (seq_len * head_num); + int seq_id = blockIdx.x / head_num % seq_len; + int head_id = blockIdx.x % head_num; + while (idx < size_per_head) { + dst[batch_id * (head_num * seq_len * size_per_head) + + head_id * seq_len * size_per_head + seq_id * size_per_head + idx] = + src[blockIdx.x * size_per_head + idx] + + bias[head_id * size_per_head + idx]; + idx += blockDim.x; + } + } else { + //(batch, head, seq_len, size_per_head) -> (batch, seq_len, head, + // size_per_head) + int batch_id = blockIdx.x / (head_num * seq_len); + int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; + int seq_id = blockIdx.x % seq_len; + + while (idx < size_per_head) { + dst[batch_id * (head_num * seq_len * size_per_head) + + seq_id * head_num * size_per_head + head_id * size_per_head + idx] = + src[blockIdx.x * size_per_head + idx]; + idx += blockDim.x; + } } } - +} // namespace /* (batch_size, seq_len, num_attention_heads, size_per_head) -> (batch_size, head_num, seq_len, size_per_head) */ -template <> -void GPUTransposeForScore(const float* input_data, float* output_data, +template +void GPUTransposeForScore(const T* input_data, const T* bias, int64_t batch_size, int64_t seq_len, int64_t num_attention_heads, int64_t size_per_head, - cudaStream_t stream) { + cudaStream_t stream, T* output_data) { dim3 grid, block; grid.x = batch_size * num_attention_heads * seq_len; block.x = min(1024, int(size_per_head)); - transpose<<>>(input_data, output_data, batch_size, - seq_len, num_attention_heads, - size_per_head); + transpose<<>>(input_data, bias, batch_size, + seq_len, num_attention_heads, + size_per_head, output_data); } +template void GPUTransposeForScore( + const float* input_data, const float* bias, int64_t batch_size, + int64_t seq_len, int64_t num_attention_heads, int64_t size_per_head, + cudaStream_t stream, float* output_data); + +template void GPUTransposeForScore( + const float* input_data, const float* bias, int64_t batch_size, + int64_t seq_len, int64_t num_attention_heads, int64_t size_per_head, + cudaStream_t stream, float* output_data); + } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/gpu_transpose_kernel.h b/turbo_transformers/layers/kernels/gpu_transpose_kernel.h index 8c60c969..74285418 100644 --- a/turbo_transformers/layers/kernels/gpu_transpose_kernel.h +++ b/turbo_transformers/layers/kernels/gpu_transpose_kernel.h @@ -25,11 +25,11 @@ void GPUSplitAddBiasTransposeForScore(const T* input_data, const T* bias_data, int64_t size_per_head, cudaStream_t stream); -template -void GPUTransposeForScore(const T* input_data, T* output_data, +template +void GPUTransposeForScore(const T* input_data, const T* bias, int64_t batch_size, int64_t seq_len, int64_t num_attention_heads, int64_t size_per_head, - cudaStream_t stream); + cudaStream_t stream, T* output_data); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/gpu_utils.cu b/turbo_transformers/layers/kernels/gpu_utils.cu index bbff5eeb..9f6cfda0 100644 --- a/turbo_transformers/layers/kernels/gpu_utils.cu +++ b/turbo_transformers/layers/kernels/gpu_utils.cu @@ -11,8 +11,6 @@ // permissions and limitations under the License. // See the AUTHORS file for names of contributors. 
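// GPUAddBias and GPUConcat below both treat their operands as a flattened
// (high_dim, mid_size, low_dim) layout:
//  * GPUAddBias broadcasts a bias of length n over m rows and, when
//    AddInput = true, also adds a second input of the same shape:
//      out[i, j] = input1[i, j] (+ input2[i, j]) + bias[j]
//  * GPUConcat joins t1 (high, t1_mid, low) and t2 (high, t2_mid, low) along
//    the middle axis into (high, t1_mid + t2_mid, low); each thread block
//    copies one output row of low_dim elements, so low_dim has to stay below
//    1024 (see the assert in GPUConcat).
// For example, concatenating shapes (2, 1, 2, 2) and (2, 1, 3, 2) at dim 2
// yields (2, 1, 5, 2), with high_dim = 2 * 1 = 2 and low_dim = 2, as in the
// CPU unit test in utils_test.cpp.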
-#include "turbo_transformers/layers/kernels/gpu_utils.h" - #include #include #include @@ -22,6 +20,8 @@ #include #include +#include "turbo_transformers/layers/kernels/gpu_utils.h" + namespace turbo_transformers { namespace layers { namespace kernels { @@ -129,6 +129,92 @@ void GPUTransform(int64_t* src_data_ptr, float* dst_data_ptr, dst_data_ptr_dev_ptr, func); } +// TODO(jiaruifang) if the lowese dimension is not 32x and <= 1024, +// implementation is not optimized +template +static __global__ void add_bias(const float* input1, const float* input2, + const float* bias, int m, int n, + float* output) { + int offset = blockIdx.x * n; + int block_dim_x = blockDim.x; + + int idx = threadIdx.x; + if (AddInput) { + while (idx < n) { + output[idx + offset] = + input1[idx + offset] + input2[idx + offset] + bias[idx]; + idx += block_dim_x; + } + } else { + while (idx < n) { + output[idx + offset] = input1[idx + offset] + bias[idx]; + idx += block_dim_x; + } + } +} + +template +void GPUAddBias(const T* input1, const T* input2, const T* bias, int64_t m, + int64_t n, cudaStream_t stream, T* output) { + dim3 grid(m); + int block_size = min(1024, (int)((n + 31) / 32 * 32)); + dim3 block(block_size); + add_bias<<>>( + input1, input2, bias, m, n, output); // m : high dim, n : low dim +} + +template void GPUAddBias(const float* input1, const float* input2, + const float* bias, int64_t m, int64_t n, + cudaStream_t stream, float* output); +template void GPUAddBias(const float* input1, const float* input2, + const float* bias, int64_t m, int64_t n, + cudaStream_t stream, float* output); + +template +__global__ void concat_kernel(const Dtype* t1, const Dtype* t2, + int64_t high_dim, int64_t t1_mid_size, + int64_t t2_mid_size, int64_t low_dim, + Dtype* out_data) { + int tid = threadIdx.x; // hidden_size idx + int gid = blockIdx.x; // batch_size idx + int out_mid_dim = t1_mid_size + t2_mid_size; + int out_high_idx = gid / out_mid_dim; + int out_mid_idx = gid % out_mid_dim; + int out_low_dix = tid; + + if (out_mid_idx < t1_mid_size) { + // copy from t1 + out_data[out_high_idx * out_mid_dim * low_dim + out_mid_idx * low_dim + + out_low_dix] = t1[out_high_idx * t1_mid_size * low_dim + + out_mid_idx * low_dim + out_low_dix]; + } else { + // copy from t2 + out_data[out_high_idx * out_mid_dim * low_dim + out_mid_idx * low_dim + + out_low_dix] = + t2[out_high_idx * t2_mid_size * low_dim + + (out_mid_idx - t1_mid_size) * low_dim + out_low_dix]; + } +} + +template +void GPUConcat(const Dtype* t1, const Dtype* t2, const int64_t high_dim, + const int64_t t1_mid_size, const int64_t t2_mid_size, + const int64_t low_dim, cudaStream_t stream, Dtype* out_data) { + assert(low_dim < 1024); + dim3 grid(high_dim * (t1_mid_size + t2_mid_size)); + int block_size = std::min((int)low_dim, 1024); + dim3 block(block_size); + concat_kernel<<>>( + t1, t2, high_dim, t1_mid_size, t2_mid_size, low_dim, + out_data); // m : high dim, n : low dim +} + +template void GPUConcat(const float* t1, const float* t2, + const int64_t high_dim, + const int64_t t1_mid_size, + const int64_t t2_mid_size, const int64_t low_dim, + cudaStream_t stream, float* out_data); + } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/gpu_utils.h b/turbo_transformers/layers/kernels/gpu_utils.h index cbbe3dfc..cd7713c5 100644 --- a/turbo_transformers/layers/kernels/gpu_utils.h +++ b/turbo_transformers/layers/kernels/gpu_utils.h @@ -13,6 +13,7 @@ #pragma once #include + #include 
"turbo_transformers/layers/types.h" namespace turbo_transformers { @@ -31,6 +32,14 @@ void GPUFill(T* data_ptr, int64_t size, T val); extern void GPUTransform(int64_t* src_data_ptr, float* dst_data_ptr, const int64_t size); +template +void GPUAddBias(const T* input1, const T* input2, const T* bias, int64_t m, + int64_t n, cudaStream_t stream, T* out); + +template +void GPUConcat(const Dtype* t1, const Dtype* t2, const int64_t high_dim, + const int64_t t1_mid_size, const int64_t t2_mid_size, + const int64_t low_dim, cudaStream_t stream, Dtype* out_data); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/gpu_utils_test.cpp b/turbo_transformers/layers/kernels/gpu_utils_test.cpp new file mode 100644 index 00000000..a313cee3 --- /dev/null +++ b/turbo_transformers/layers/kernels/gpu_utils_test.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. + +#include "loguru.hpp" +#ifdef TT_WITH_CUDA +#include "turbo_transformers/core/cuda_device_context.h" +#include "turbo_transformers/core/tensor.h" +#include "turbo_transformers/layers/kernels/common.h" +#include "turbo_transformers/layers/kernels/utils.h" +#endif +#include "catch2/catch.hpp" +#include "turbo_transformers/core/enforce.h" + +namespace turbo_transformers { +namespace layers { +namespace kernels { + +#ifdef TT_WITH_CUDA +template +static void ConcatTestHelper(int batch_size, int dim1, int dim2, + int hidden_size, const Func& func) { + core::Tensor cpu_t1(nullptr), cpu_t2(nullptr), cpu_out(nullptr), + gpu_t1(nullptr), gpu_t2(nullptr), gpu_out(nullptr); + std::tie(cpu_t1, gpu_t1) = common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, 2, dim1, hidden_size}); + std::tie(cpu_t2, gpu_t2) = common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, 2, dim2, hidden_size}); + func(cpu_t1, cpu_t2, cpu_out, gpu_t1, gpu_t2, gpu_out); +} + +TEST_CASE("gpu-concat") { + for (auto hidden_size : {16}) { + for (auto batch_size : {1, 5}) { + ConcatTestHelper( + batch_size, 7, 11, hidden_size, + [](core::Tensor& cpu_t1, core::Tensor& cpu_t2, core::Tensor& cpu_out, + core::Tensor& gpu_t1, core::Tensor& gpu_t2, + core::Tensor& gpu_out) { + kernels::Concat(cpu_t1, cpu_t2, 2, &cpu_out); + kernels::Concat(gpu_t1, gpu_t2, 2, &gpu_out); + REQUIRE(common::CheckResultOfCPUAndGPU(cpu_out, gpu_out)); + }); + } + } +} +#endif + +} // namespace kernels +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/layer_norm.cpp b/turbo_transformers/layers/kernels/layer_norm.cpp index c53c200e..8593a4be 100644 --- a/turbo_transformers/layers/kernels/layer_norm.cpp +++ b/turbo_transformers/layers/kernels/layer_norm.cpp @@ -20,6 +20,9 @@ #include "turbo_transformers/layers/kernels/common.h" #include "turbo_transformers/layers/kernels/gpu_layer_norm_kernel.h" #endif +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif 
namespace turbo_transformers { namespace layers { @@ -28,7 +31,11 @@ static constexpr float g_epsilon = 1e-12; template void LayerNorm(const core::Tensor& gamma, const core::Tensor& beta, - core::Tensor* out_tensor) { + core::Tensor* out_tensor, T eps, const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, out_tensor->device_type()); +#endif TT_ENFORCE_EQ( common::is_same_device_ctx(gamma.device_ctx(), beta.device_ctx()), true, "LayerNorm gamma and beta must be on the same device context."); @@ -57,7 +64,7 @@ void LayerNorm(const core::Tensor& gamma, const core::Tensor& beta, mean = mean / feature_dim; var = var / feature_dim - mean * mean; - var = 1.f / sqrtf(var + g_epsilon); + var = 1.f / sqrtf(var + eps); #pragma omp simd for (int64_t i = 0; i < feature_dim; ++i) { @@ -78,18 +85,26 @@ void LayerNorm(const core::Tensor& gamma, const core::Tensor& beta, TT_THROW("AddBiasLayerNorm device_type %d is not supported", out_tensor->device_type()); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, out_tensor->device_type()); +#endif } template void LayerNorm(const core::Tensor& gamma, const core::Tensor& beta, - core::Tensor* out_tensor); + core::Tensor* out_tensor, float eps, + const std::string name); template void AddBiasLayerNorm(const core::Tensor& input_tensor, const core::Tensor& bias_tensor, const core::Tensor& gamma_tensor, - const core::Tensor& beta_tensor, - core::Tensor* out_tensor) { + const core::Tensor& beta_tensor, core::Tensor* out_tensor, + T eps, const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, input_tensor.device_type()); +#endif TT_ENFORCE_EQ(common::is_same_device_ctx(input_tensor.device_ctx(), bias_tensor.device_ctx()), true, @@ -131,7 +146,7 @@ void AddBiasLayerNorm(const core::Tensor& input_tensor, mean = mean / n; var = var / n - mean * mean; - var = 1.f / sqrtf(var + g_epsilon); + var = 1.f / sqrtf(var + eps); #pragma omp simd for (int64_t i = 0; i < n; ++i) { @@ -151,13 +166,17 @@ void AddBiasLayerNorm(const core::Tensor& input_tensor, TT_THROW("LayerNorm device_type %d is not supported", input_tensor.device_type()); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, input_tensor.device_type()); +#endif } template void AddBiasLayerNorm(const core::Tensor& input_tensor, const core::Tensor& bias_tensor, const core::Tensor& gamma_tensor, const core::Tensor& beta_tensor, - core::Tensor* out_tensor); + core::Tensor* out_tensor, float eps, + const std::string name); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/layer_norm.h b/turbo_transformers/layers/kernels/layer_norm.h index 1298b775..bba0aabf 100644 --- a/turbo_transformers/layers/kernels/layer_norm.h +++ b/turbo_transformers/layers/kernels/layer_norm.h @@ -22,14 +22,16 @@ namespace kernels { template extern void LayerNorm(const core::Tensor& gamma, const core::Tensor& beta, - core::Tensor* out_tensor); + core::Tensor* out_tensor, T eps = 1e-12, + const std::string name = "LayerNorm"); template extern void AddBiasLayerNorm(const core::Tensor& input_tensor, const core::Tensor& bias_tensor, const core::Tensor& gamma_tensor, const core::Tensor& beta_tensor, - core::Tensor* out_tensor); + core::Tensor* out_tensor, T eps = 1e-12, + const std::string name = "AddBiasLayerNorm"); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/layer_norm_test.cpp 
b/turbo_transformers/layers/kernels/layer_norm_test.cpp index 7f41595c..76c8c15b 100644 --- a/turbo_transformers/layers/kernels/layer_norm_test.cpp +++ b/turbo_transformers/layers/kernels/layer_norm_test.cpp @@ -53,9 +53,6 @@ TEST_CASE("add_bias_layer_norm-test") { std::tie(cpu_beta, gpu_beta) = common::CreateAndFillRandomForCPUGPUTensors({hidden_size}); - std::cout << "batch_size: " << batch_size - << " seq_length: " << seq_length - << " hidden_size: " << hidden_size; { LayerNorm(cpu_gamma, cpu_beta, &cpu_out); LayerNorm(gpu_gamma, gpu_beta, &gpu_out); diff --git a/turbo_transformers/layers/kernels/mat_mul.cpp b/turbo_transformers/layers/kernels/mat_mul.cpp index ef2492f2..7a083f23 100644 --- a/turbo_transformers/layers/kernels/mat_mul.cpp +++ b/turbo_transformers/layers/kernels/mat_mul.cpp @@ -20,12 +20,20 @@ #include "turbo_transformers/core/cuda_device_context.h" #include "turbo_transformers/core/cuda_enforce.cuh" #endif +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif namespace turbo_transformers { namespace layers { namespace kernels { void MatMul(const core::Tensor& A, bool a_trans, const core::Tensor& B, - bool b_trans, float alpha, core::Tensor* out, float beta) { + bool b_trans, float alpha, core::Tensor* out, float beta, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, A.device_type()); +#endif BlasInt a_cols = A.shape(-1); BlasInt a_rows = A.numel() / a_cols; BlasInt b_cols = B.shape(-1); @@ -97,9 +105,17 @@ void MatMul(const core::Tensor& A, bool a_trans, const core::Tensor& B, } else { TT_THROW("device_type %d is not supported for MatMul", A.device_type()); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, A.device_type()); +#endif } void BatchMatMul(const core::Tensor& A, bool a_trans, const core::Tensor& B, - bool b_trans, float alpha, core::Tensor* C, float beta) { + bool b_trans, float alpha, core::Tensor* C, float beta, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, A.device_type()); +#endif auto* A_shape = &A.shape(0); auto A_ndim = A.n_dim(); auto* B_shape = &B.shape(0); @@ -184,6 +200,9 @@ void BatchMatMul(const core::Tensor& A, bool a_trans, const core::Tensor& B, } else { TT_THROW("device_type %d is not supported!", A.device_type()); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, A.device_type()); +#endif } } // namespace kernels diff --git a/turbo_transformers/layers/kernels/mat_mul.h b/turbo_transformers/layers/kernels/mat_mul.h index 265cbec6..bece84f4 100644 --- a/turbo_transformers/layers/kernels/mat_mul.h +++ b/turbo_transformers/layers/kernels/mat_mul.h @@ -17,10 +17,12 @@ namespace turbo_transformers { namespace layers { namespace kernels { extern void MatMul(const core::Tensor& A, bool a_trans, const core::Tensor& B, - bool b_trans, float alpha, core::Tensor* out, float beta); + bool b_trans, float alpha, core::Tensor* out, float beta, + const std::string name = "MatMul"); extern void BatchMatMul(const core::Tensor& A, bool a_trans, const core::Tensor& B, bool b_trans, float alpha, - core::Tensor* C, float beta); + core::Tensor* C, float beta, + const std::string name = "BatchMatMul"); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/mat_mul_test.cpp b/turbo_transformers/layers/kernels/mat_mul_test.cpp index 941e2be2..4c6940c0 100644 --- a/turbo_transformers/layers/kernels/mat_mul_test.cpp +++ 
b/turbo_transformers/layers/kernels/mat_mul_test.cpp @@ -123,7 +123,6 @@ TEST_CASE("matmul-gpu-test") { check_cpu_gpu_res(false); } #endif - } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/softmax.cpp b/turbo_transformers/layers/kernels/softmax.cpp index b29d1447..698dff56 100644 --- a/turbo_transformers/layers/kernels/softmax.cpp +++ b/turbo_transformers/layers/kernels/softmax.cpp @@ -23,22 +23,40 @@ namespace turbo_transformers { namespace layers { namespace kernels { +// attr_mask's shape could be (batch, from_len, to_len), (batch, 1, to_len) or +// nullptr is2D is used to distinguish the two scenarios. void SoftmaxMask(float* qk_buf, const float* attr_mask, int64_t batch_size, - int64_t head_num, int64_t seq_len, float scale) { - int64_t M = batch_size * head_num * seq_len; - int64_t N = seq_len; + int64_t head_num, int64_t from_seq_len, int64_t to_seq_len, + float scale, bool is2D = true) { + int64_t M = batch_size * head_num * from_seq_len; + int64_t N = to_seq_len; #pragma omp parallel for for (int64_t i = 0; i < M; ++i) { auto* qk_buf_ptr = qk_buf + i * N; - auto attr_mask_offset = i / (head_num * seq_len) * seq_len; - auto attr_mask_ptr = attr_mask + attr_mask_offset; - // max-trick + if (attr_mask != nullptr) { + const float* attr_mask_ptr; + auto batch_idx = i / (head_num * from_seq_len); + auto from_seq_idx = i % from_seq_len; + attr_mask_ptr = + attr_mask + + (is2D ? batch_idx * to_seq_len + : (batch_idx * from_seq_len + from_seq_idx) * to_seq_len); +// max-trick #pragma omp simd - for (int64_t j = 0; j < N; ++j) { - auto mask_val = attr_mask_ptr[j]; - auto qk_val = qk_buf_ptr[j]; - qk_val = qk_val * scale + mask_val; - qk_buf_ptr[j] = qk_val; + for (int64_t j = 0; j < N; ++j) { + auto qk_val = qk_buf_ptr[j]; + auto mask_val = attr_mask_ptr[j]; + qk_val = qk_val * scale + mask_val; + qk_buf_ptr[j] = qk_val; + } + } else { +// max-trick +#pragma omp simd + for (int64_t j = 0; j < N; ++j) { + auto qk_val = qk_buf_ptr[j]; + qk_val = qk_val * scale; + qk_buf_ptr[j] = qk_val; + } } float max_val = std::numeric_limits::lowest(); #pragma omp simd reduction(max : max_val) @@ -61,19 +79,39 @@ void SoftmaxMask(float* qk_buf, const float* attr_mask, int64_t batch_size, } } } + void ApplyMaskAndSoftmax(core::Tensor* inout, const core::Tensor& att_mask, - float scale) { + float scale, const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, inout->device_type()); +#endif auto batch_size = inout->shape(0); auto num_att_heads = inout->shape(1); - auto seq_len = inout->shape(2); + auto from_seq_len = inout->shape(2); + auto to_seq_len = inout->shape(3); + bool is_2D = false; + if (!att_mask.is_null()) { + if (att_mask.n_dim() == 2 || + (att_mask.n_dim() == 3 && att_mask.shape(1) == 1) || + (att_mask.n_dim() == 4 && att_mask.shape(2) == 1)) { + is_2D = true; + } else { + is_2D = false; + } + } + const float* att_mask_data = nullptr; + if (!att_mask.is_null()) { + att_mask_data = att_mask.data(); + } if (inout->device_type() == kDLCPU) { - SoftmaxMask(inout->mutableData(), att_mask.data(), batch_size, - num_att_heads, seq_len, scale); + SoftmaxMask(inout->mutableData(), att_mask_data, batch_size, + num_att_heads, from_seq_len, to_seq_len, scale, is_2D); } else if (inout->device_type() == kDLGPU) { #ifdef TT_WITH_CUDA auto& cuda_ctx = core::CUDADeviceContext::GetInstance(); - GPUSoftmaxMask(inout->mutableData(), att_mask.data(), - 
batch_size, num_att_heads, seq_len, scale, + GPUSoftmaxMask(inout->mutableData(), att_mask_data, batch_size, + num_att_heads, from_seq_len, to_seq_len, scale, is_2D, cuda_ctx.stream()); #else TT_THROW("The current code is not compiled with CUDA."); @@ -81,6 +119,9 @@ void ApplyMaskAndSoftmax(core::Tensor* inout, const core::Tensor& att_mask, } else { TT_THROW("device_type is not supported"); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, inout->device_type()); +#endif } } // namespace kernels diff --git a/turbo_transformers/layers/kernels/softmax.h b/turbo_transformers/layers/kernels/softmax.h index 4a7e6d0c..226f0877 100644 --- a/turbo_transformers/layers/kernels/softmax.h +++ b/turbo_transformers/layers/kernels/softmax.h @@ -20,7 +20,8 @@ namespace turbo_transformers { namespace layers { namespace kernels { extern void ApplyMaskAndSoftmax(core::Tensor* inout, - const core::Tensor& att_mask, float scale); + const core::Tensor& att_mask, float scale, + const std::string name = "ApplyMaskAndSoftmax"); } // namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/softmax_test.cpp b/turbo_transformers/layers/kernels/softmax_test.cpp index 4a936699..992ee589 100644 --- a/turbo_transformers/layers/kernels/softmax_test.cpp +++ b/turbo_transformers/layers/kernels/softmax_test.cpp @@ -29,34 +29,67 @@ namespace layers { namespace kernels { #ifdef TT_WITH_CUDA -TEST_CASE("softmax-gpu-test") { +TEST_CASE("softmax-gpu-2Dmask-test") { int64_t num_attention_heads = 12; constexpr float scaler = 1.; std::vector batch_size_list{1, 20}; - std::vector seq_length_list{10, 20, 40, 60, 80, - 100, 200, 300, 400, 500}; + std::vector from_seq_list{10, 20, 40, 60, 80, + 100, 200, 300, 400, 500}; + std::vector to_seq_list{10, 20, 40, 60, 80, 100}; for (auto batch_size : batch_size_list) - for (auto seq_length : seq_length_list) { - core::Tensor qk_buf_cpu(nullptr), qk_buf_gpu(nullptr); - std::tie(qk_buf_cpu, qk_buf_gpu) = - common::CreateAndFillRandomForCPUGPUTensors( - {batch_size, num_attention_heads, seq_length, seq_length}); + for (auto from_seq : from_seq_list) + for (auto to_seq : to_seq_list) { + core::Tensor qk_buf_cpu(nullptr), qk_buf_gpu(nullptr); + std::tie(qk_buf_cpu, qk_buf_gpu) = + common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, num_attention_heads, from_seq, to_seq}); - core::Tensor attr_mask_cpu(nullptr), attr_mask_gpu(nullptr); - std::tie(attr_mask_cpu, attr_mask_gpu) = - common::CreateAndFillRandomForCPUGPUTensors( - {batch_size, seq_length}); + core::Tensor attr_mask_cpu(nullptr), attr_mask_gpu(nullptr); + std::tie(attr_mask_cpu, attr_mask_gpu) = + common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, 1, 1, to_seq}); - ApplyMaskAndSoftmax(&qk_buf_gpu, attr_mask_gpu, scaler); + ApplyMaskAndSoftmax(&qk_buf_gpu, attr_mask_gpu, scaler); - ApplyMaskAndSoftmax(&qk_buf_cpu, attr_mask_cpu, scaler); + ApplyMaskAndSoftmax(&qk_buf_cpu, attr_mask_cpu, scaler); - REQUIRE(common::CheckResultOfCPUAndGPU(qk_buf_cpu, qk_buf_gpu)); - } + REQUIRE(common::CheckResultOfCPUAndGPU(qk_buf_cpu, qk_buf_gpu)); + } } + +TEST_CASE("softmax-gpu-3Dmask-test") { + int64_t num_attention_heads = 12; + + constexpr float scaler = 1.; + + std::vector batch_size_list{1, 20}; + std::vector from_seq_list{10, 20, 40, 60, 80, + 100, 200, 300, 400, 500}; + std::vector to_seq_list{10, 20, 40, 60, 80, 100}; + for (auto batch_size : batch_size_list) + for (auto from_seq : from_seq_list) + for (auto to_seq : to_seq_list) { + core::Tensor qk_buf_cpu(nullptr), qk_buf_gpu(nullptr); + 
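// This test covers the "3D" mask layout: the mask carries an explicit
// from_seq dimension, {batch, 1, from_seq, to_seq}, so every query position
// has its own row of mask values (e.g. what a decoder's causal mask needs).
// The 2D test above uses {batch, 1, 1, to_seq}, which is broadcast over all
// query positions; ApplyMaskAndSoftmax detects the layout itself and forwards
// an is_2D flag to the CPU and GPU kernels.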
std::tie(qk_buf_cpu, qk_buf_gpu) = + common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, num_attention_heads, from_seq, to_seq}); + + core::Tensor attr_mask_cpu(nullptr), attr_mask_gpu(nullptr); + std::tie(attr_mask_cpu, attr_mask_gpu) = + common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, 1, from_seq, to_seq}); + + ApplyMaskAndSoftmax(&qk_buf_gpu, attr_mask_gpu, scaler); + + ApplyMaskAndSoftmax(&qk_buf_cpu, attr_mask_cpu, scaler); + + REQUIRE(common::CheckResultOfCPUAndGPU(qk_buf_cpu, qk_buf_gpu)); + } +} + #endif } // namespace kernels diff --git a/turbo_transformers/layers/kernels/transpose.cpp b/turbo_transformers/layers/kernels/transpose.cpp index b2bc546e..bfa5ec90 100644 --- a/turbo_transformers/layers/kernels/transpose.cpp +++ b/turbo_transformers/layers/kernels/transpose.cpp @@ -21,11 +21,40 @@ #include "turbo_transformers/core/cuda_device_context.h" #include "turbo_transformers/layers/kernels/gpu_transpose_kernel.h" #endif +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif namespace turbo_transformers { namespace layers { namespace kernels { +// tranpose(2,3) +// input (B x seq_len x head_num * hidden_size) +// bias (head * hidden_size) +static void AddBiasTransposeForScoreImpl(const float* input, const float* bias, + int64_t dim0, int64_t dim1, + int64_t dim2, int64_t dim3, + float* output) { +#pragma omp parallel for + for (int64_t idx = 0; idx < dim0 * dim1; ++idx) { + int64_t dim0_idx = idx / dim1; + int64_t dim1_idx = idx % dim1; + for (int64_t dim2_idx = 0; dim2_idx < dim2; ++dim2_idx) { + const float* bias_ptr = bias + dim2_idx * dim3; + auto* src = input + dim0_idx * (dim1 * dim2 * dim3) + + dim1_idx * dim2 * dim3 + dim2_idx * dim3; + auto* dst = output + dim0_idx * (dim1 * dim2 * dim3) + + dim2_idx * dim1 * dim3 + dim1_idx * dim3; +#pragma omp simd + for (int64_t dim3_idx = 0; dim3_idx < dim3; ++dim3_idx) { + dst[dim3_idx] = src[dim3_idx] + bias_ptr[dim3_idx]; + } + } + } +} + +// tranpose(2,3) static void TransposeForScoreImpl(float* output, const float* input, int64_t batch_size, int64_t seq_length, int64_t num_attention_heads, int64_t width) { @@ -48,7 +77,12 @@ static void TransposeForScoreImpl(float* output, const float* input, } } -void TransposeForScore(core::Tensor* output, const core::Tensor& input) { +void TransposeForScore(core::Tensor* output, const core::Tensor& input, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, input.device_type()); +#endif if (input.device_type() == kDLCPU && output->device_type() == kDLCPU) { TransposeForScoreImpl(output->mutableData(), input.data(), output->shape(0), output->shape(1), input.shape(1), @@ -60,24 +94,69 @@ void TransposeForScore(core::Tensor* output, const core::Tensor& input) { auto num_attention_heads = input.shape(1); auto width = input.shape(3); core::CUDADeviceContext& cuda_ctx = core::CUDADeviceContext::GetInstance(); - GPUTransposeForScore( - input.data(), output->mutableData(), batch_size, - seq_length, num_attention_heads, width, cuda_ctx.stream()); + const float* dummy = nullptr; + GPUTransposeForScore( + input.data(), dummy, batch_size, seq_length, num_attention_heads, + width, cuda_ctx.stream(), output->mutableData()); +#endif + } else { + TT_THROW("device_type is not supported"); + } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, input.device_type()); +#endif +} + +// add bias and transpose(2,3) +void AddBiasTransposeForScore(const core::Tensor& input, + const 
core::Tensor& bias, core::Tensor* output, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, input.device_type()); +#endif + TT_ENFORCE_EQ(input.n_dim(), 4, "input should be a 4-D tensor"); + TT_ENFORCE_EQ(bias.numel(), input.shape(2) * input.shape(3), + "bias shape %d should be %d x %d", bias.n_dim(), input.shape(2), + input.shape(3)); + auto dim0 = input.shape(0); + auto dim1 = input.shape(1); + auto dim2 = input.shape(2); + auto dim3 = input.shape(3); + if (input.device_type() == kDLCPU && output->device_type() == kDLCPU) { + AddBiasTransposeForScoreImpl(input.data(), bias.data(), dim0, + dim1, dim2, dim3, + output->mutableData()); + } else if (input.device_type() == kDLGPU && output->device_type() == kDLGPU) { +#ifdef TT_WITH_CUDA + core::CUDADeviceContext& cuda_ctx = core::CUDADeviceContext::GetInstance(); + GPUTransposeForScore(input.data(), bias.data(), + dim0, dim1, dim2, dim3, cuda_ctx.stream(), + output->mutableData()); #endif } else { TT_THROW("device_type is not supported"); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, input.device_type()); +#endif } void SplitAddBiasTransposeForScore(core::Tensor* output_tensor, const core::Tensor& input_tensor, - const core::Tensor& bias_tensor) { + const core::Tensor& bias_tensor, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, input_tensor.device_type()); +#endif + TT_ENFORCE_EQ(output_tensor->n_dim(), 5, "output_tensor should be (weight_num, batch_size, seq_length, " "num_attention_heads, size_per_head)"); - TT_ENFORCE_EQ(output_tensor->shape(0), 3, - "output_tensor should be (3, batch_size, seq_length, " - "num_attention_heads, size_per_head)"); + // TT_ENFORCE_EQ(bias_tensor.n_dim(), 1, + // "output_tensor should be (weight_num * num_attention_heads, " + // "size_per_head)"); auto batch_size = output_tensor->shape(1); auto seq_length = output_tensor->shape(3); @@ -140,6 +219,9 @@ void SplitAddBiasTransposeForScore(core::Tensor* output_tensor, } else { TT_THROW("device_type is not supported"); } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, input_tensor.device_type()); +#endif } } // namespace kernels diff --git a/turbo_transformers/layers/kernels/transpose.h b/turbo_transformers/layers/kernels/transpose.h index 1af48c5d..fb8e0f7b 100644 --- a/turbo_transformers/layers/kernels/transpose.h +++ b/turbo_transformers/layers/kernels/transpose.h @@ -34,14 +34,20 @@ namespace kernels { output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) return output_tensor * **/ -extern void TransposeForScore(core::Tensor* output, const core::Tensor& input); +extern void TransposeForScore(core::Tensor* output, const core::Tensor& input, + const std::string name = "TransposeForScore"); + +extern void AddBiasTransposeForScore( + const core::Tensor& input, const core::Tensor& bias, core::Tensor* output, + const std::string name = "AddBiasTransposeForScore"); // input: (batch_size, seq_length, 3, head_num, *size_per_head) // bias: (3, head_num, size_per_head) // output: (3, batch_size, num_attention_heads, seq_length, size_per_head) -extern void SplitAddBiasTransposeForScore(core::Tensor* output, - const core::Tensor& input_tensor, - const core::Tensor& bias_tensor); +extern void SplitAddBiasTransposeForScore( + core::Tensor* output, const core::Tensor& input_tensor, + const core::Tensor& bias_tensor, + const std::string name = "SplitAddBiasTransposeForScore"); } 
// namespace kernels } // namespace layers diff --git a/turbo_transformers/layers/kernels/transpose_test.cpp b/turbo_transformers/layers/kernels/transpose_test.cpp index e510d8dc..91f36497 100644 --- a/turbo_transformers/layers/kernels/transpose_test.cpp +++ b/turbo_transformers/layers/kernels/transpose_test.cpp @@ -93,6 +93,42 @@ TEST_CASE("transpose-gpu-test") { output_tensor_gpu)); } } + +TEST_CASE("transpose-bias-gpu-test") { + const std::vector num_attention_heads_list{12, 20, 24}; + const std::vector batch_size_list{ + 1, + 20, + }; + const std::vector seq_length_list{10, 32, 128}; + + for (auto num_attention_heads : num_attention_heads_list) + for (auto batch_size : batch_size_list) + for (auto seq_length : seq_length_list) { + core::Tensor input_tensor_cpu(nullptr), input_tensor_gpu(nullptr); + core::Tensor bias_tensor_cpu(nullptr), bias_tensor_gpu(nullptr); + std::tie(input_tensor_cpu, input_tensor_gpu) = + common::CreateAndFillRandomForCPUGPUTensors( + {batch_size, seq_length, num_attention_heads, 64}); + std::tie(bias_tensor_cpu, bias_tensor_gpu) = + common::CreateAndFillRandomForCPUGPUTensors( + {num_attention_heads, 64}); + + turbo_transformers::core::Tensor output_tensor_gpu( + turbo_transformers::core::NewDLPackTensorT( + {batch_size, num_attention_heads, seq_length, 64}, kDLGPU, 0)); + turbo_transformers::core::Tensor output_tensor_cpu( + turbo_transformers::core::NewDLPackTensorT( + {batch_size, num_attention_heads, seq_length, 64}, kDLCPU, 0)); + + AddBiasTransposeForScore(input_tensor_gpu, bias_tensor_gpu, + &output_tensor_gpu); + AddBiasTransposeForScore(input_tensor_cpu, bias_tensor_cpu, + &output_tensor_cpu); + REQUIRE(common::CheckResultOfCPUAndGPU(output_tensor_cpu, + output_tensor_gpu)); + } +} #endif } // namespace kernels diff --git a/turbo_transformers/layers/kernels/utils.cpp b/turbo_transformers/layers/kernels/utils.cpp new file mode 100644 index 00000000..d436ed08 --- /dev/null +++ b/turbo_transformers/layers/kernels/utils.cpp @@ -0,0 +1,173 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. 
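// Rough semantics of the helpers defined in this file, with all leading
// dimensions flattened into rows of length bias.shape(0):
//   AddBias(bias, out):                out[i, j] += bias[j]
//   AddInputBias(in1, in2, bias, out): out[i, j] = in1[i, j] + in2[i, j] + bias[j]
//   Concat<float>(t1, t2, dim, out):   concatenate t1 and t2 along `dim`; all
//       other dimensions must match. The CPU path copies row by row via
//       core::Copy, the GPU path dispatches to GPUConcat.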
+ +#include "utils.h" + +#include "common.h" +#ifdef TT_WITH_CUDA +#include + +#include "turbo_transformers/core/cuda_device_context.h" +#include "turbo_transformers/core/cuda_enforce.cuh" +#include "turbo_transformers/layers/kernels/gpu_utils.h" +#endif +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif + +namespace turbo_transformers { +namespace layers { +namespace kernels { + +template +void Concat(const core::Tensor& t1, const core::Tensor& t2, size_t dim, + core::Tensor* output, const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, t1.device_type()); +#endif + TT_ENFORCE(t1.n_dim() >= dim && t2.n_dim() >= dim, + "concatation of two tensors with dim as %d and %d is illegal.", + t1.n_dim(), t2.n_dim()); + + auto t1_size = t1.shape(dim); + auto t2_size = t2.shape(dim); + + std::vector output_shape; + for (size_t i = 0; i < t1.n_dim(); i++) { + if (i != dim) { + TT_ENFORCE( + t1.shape(i) == t2.shape(i), + "concatation of two tensors illegal, at dim %d size is %d vs %d", i, + t1.shape(i), t2.shape(i)); + output_shape.push_back(t1.shape(i)); + } else { + output_shape.push_back(t1_size + t2_size); + } + } + + int64_t high_dim = 1; + for (size_t i = 0; i < dim; i++) { + high_dim *= t1.shape(i); + } + + size_t low_dim = 1; + for (size_t i = t1.n_dim() - 1; i > dim; i--) { + low_dim *= t2.shape(i); + } + + output->Reshape(output_shape, t1.device_type(), t1.device_id(), + "Concat/Reshape"); + if (t1.device_type() == kDLGPU) { +#ifdef TT_WITH_CUDA + core::CUDADeviceContext& cuda_ctx = core::CUDADeviceContext::GetInstance(); + GPUConcat(t1.data(), t2.data(), high_dim, t1_size, t2_size, + low_dim, cuda_ctx.stream(), output->mutableData()); +#endif + } else if (t1.device_type() == kDLCPU) { +#pragma omp parallel for + for (int64_t i = 0; i < high_dim; ++i) { + for (int64_t j = 0; j < t1_size; ++j) { + core::Copy( + t1.data() + (i * t1_size + j) * low_dim, low_dim, + t1.device_type(), output->device_type(), + output->mutableData() + (i * (t1_size + t2_size) + j) * low_dim); + } + for (int64_t j = 0; j < t2_size; ++j) { + core::Copy(t2.data() + (i * t2_size + j) * low_dim, low_dim, + t1.device_type(), output->device_type(), + output->mutableData() + + (i * (t1_size + t2_size) + t1_size + j) * low_dim); + } + } + } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, t1.device_type()); +#endif +} + +template void Concat(const core::Tensor& t1, const core::Tensor& t2, + size_t dim, core::Tensor* output, + const std::string name); + +void AddBias(const core::Tensor& bias, core::Tensor* output, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, bias.device_type()); +#endif + auto dim1 = bias.shape(0); + auto dim0 = output->numel() / dim1; + auto output_data = output->mutableData(); + const auto bias_data = bias.data(); + if (bias.device_type() == kDLCPU && output->device_type() == kDLCPU) { +#pragma omp parallel for + for (int64_t i = 0; i < dim0; ++i) { +#pragma omp simd + for (int64_t j = 0; j < dim1; ++j) { + output_data[i * dim1 + j] += bias_data[j]; + } + } + } else { +#ifdef TT_WITH_CUDA + core::CUDADeviceContext& cuda_ctx = core::CUDADeviceContext::GetInstance(); + const float* dummy{nullptr}; + kernels::GPUAddBias(output_data, dummy, bias_data, dim0, dim1, + cuda_ctx.stream(), output_data); +#endif + } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, bias.device_type()); +#endif +} + +void 
AddInputBias(const core::Tensor& input1, const core::Tensor& input2, + const core::Tensor& bias, core::Tensor* output, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, input1.device_type()); +#endif + TT_ENFORCE_EQ(input1.numel(), input2.numel(), + "Tensor input1 and Tensor input2 should have the same numel."); + auto dim1 = bias.shape(0); + auto dim0 = output->numel() / dim1; + auto output_data = output->mutableData(); + const auto bias_data = bias.data(); + const auto input1_data = input1.data(); + const auto input2_data = input2.data(); + + if (input1.device_type() == kDLCPU && output->device_type() == kDLCPU) { +#pragma omp parallel for + for (int64_t i = 0; i < dim0; ++i) { +#pragma omp simd + for (int64_t j = 0; j < dim1; ++j) { + output_data[i * dim1 + j] = bias_data[j] + input1_data[i * dim1 + j] + + input2_data[i * dim1 + j]; + } + } + } else { +#ifdef TT_WITH_CUDA + core::CUDADeviceContext& cuda_ctx = core::CUDADeviceContext::GetInstance(); + GPUAddBias(input1_data, input2_data, bias_data, dim0, dim1, + cuda_ctx.stream(), output_data); +#endif + } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, input1.device_type()); +#endif +} + +} // namespace kernels +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/utils.h b/turbo_transformers/layers/kernels/utils.h new file mode 100644 index 00000000..ec2c2190 --- /dev/null +++ b/turbo_transformers/layers/kernels/utils.h @@ -0,0 +1,31 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. + +#pragma once +#include "turbo_transformers/core/tensor.h" +namespace turbo_transformers { +namespace layers { +namespace kernels { + +void AddBias(const core::Tensor& bias, core::Tensor* output, + const std::string name = "Concat"); +void AddInputBias(const core::Tensor& input1, const core::Tensor& input2, + const core::Tensor& bias, core::Tensor* output, + const std::string name = "Concat"); +template +void Concat(const core::Tensor& t1, const core::Tensor& t2, size_t dim, + core::Tensor* output, const std::string name = "Concat"); + +} // namespace kernels +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/utils_test.cpp b/turbo_transformers/layers/kernels/utils_test.cpp new file mode 100644 index 00000000..f8c2716e --- /dev/null +++ b/turbo_transformers/layers/kernels/utils_test.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. 
You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. +#include "turbo_transformers/layers/kernels/utils.h" + +#include "catch2/catch.hpp" + +namespace turbo_transformers { +namespace layers { +namespace kernels { + +TEST_CASE("cpu-concate", "test1") { + turbo_transformers::core::Tensor t1( + turbo_transformers::core::NewDLPackTensorT({2, 1, 2, 2})); + turbo_transformers::core::Tensor t2( + turbo_transformers::core::NewDLPackTensorT({2, 1, 3, 2})); + for (int i = 0; i < t1.numel(); ++i) { + t1.mutableData()[i] = i * 1.0; + } + + for (int i = 0; i < t2.numel(); ++i) { + t2.mutableData()[i] = i * 100.0; + } + turbo_transformers::core::Tensor res1(nullptr), res2(nullptr), res3(nullptr); + Concat(t1, t2, 2, &res1); + // res1.Print(std::cerr); + + turbo_transformers::core::Tensor t3( + turbo_transformers::core::NewDLPackTensorT({4, 2})); + turbo_transformers::core::Tensor t4( + turbo_transformers::core::NewDLPackTensorT({3, 2})); + Concat(t3, t4, 0, &res2); + + turbo_transformers::core::Tensor t5( + turbo_transformers::core::NewDLPackTensorT({2, 3})); + turbo_transformers::core::Tensor t6( + turbo_transformers::core::NewDLPackTensorT({2, 4})); + for (int i = 0; i < t5.numel(); ++i) { + t5.mutableData()[i] = i * 1.0; + } + + for (int i = 0; i < t6.numel(); ++i) { + t6.mutableData()[i] = i * 100.0; + } + Concat(t5, t6, 1, &res2); + // res3.Print(std::cerr); + + // t5.Print(std::cerr); + // t6.Print(std::cerr); + // res3.Print(std::cerr); + // REQUIRE(test_tensor.n_dim() == 2); + // REQUIRE(test_tensor.numel() == 3 * 4); +} + +} // namespace kernels +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/multi_headed_attention.cpp b/turbo_transformers/layers/multi_headed_attention.cpp new file mode 100644 index 00000000..49788807 --- /dev/null +++ b/turbo_transformers/layers/multi_headed_attention.cpp @@ -0,0 +1,333 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. 
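// A sketch of the call contract of MultiHeadedAttention::operator(), inferred
// from how it is invoked elsewhere in this change:
//  * attn_type == "self":    key/value/query all come from the same tensor
//                            (BERT or decoder self-attention), so the key
//                            sequence length equals the query sequence length.
//  * attn_type == "context": the query comes from the decoder while key/value
//                            come from the encoder memory, so the two sequence
//                            lengths may differ.
//  * layer_cache may hold "self_keys"/"self_values" and
//    "memory_keys"/"memory_values" tensors so a decoder can reuse projected
//    keys/values across generation steps; an empty map disables caching.
//  * pre_layernorm / post_layernorm / post_add_input choose where layer norm
//    and the residual add are applied; BertAttention, for instance, forwards
//    with attn_type "self", post_layernorm = true and the other flags false.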
+ +#include "turbo_transformers/layers/multi_headed_attention.h" + +#include "loguru.hpp" +#include "turbo_transformers/core/memory.h" +#include "turbo_transformers/layers/kernels/common.h" +#include "turbo_transformers/layers/kernels/layer_norm.h" +#include "turbo_transformers/layers/kernels/mat_mul.h" +#include "turbo_transformers/layers/kernels/softmax.h" +#include "turbo_transformers/layers/kernels/transpose.h" +#include "turbo_transformers/layers/kernels/utils.h" + +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif + +namespace turbo_transformers { +namespace layers { + +static std::mutex mutex_; + +void MultiHeadedAttention::operator()( + const core::Tensor& key_tensor, const core::Tensor& value_tensor, + const core::Tensor& query_tensor, const core::Tensor& attention_mask, + const std::string& attn_type, core::Tensor* output, core::Tensor* att_score, + std::unordered_map layer_cache, + bool pre_layernorm, bool post_layernorm, bool post_add_input, + bool is_trans_weight) const { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile("MultiHeadedAttention_" + attn_type, + query_tensor.device_type()); +#endif + std::lock_guard g(mutex_); + + TT_ENFORCE_EQ(key_tensor.n_dim(), 3, + "The key_tensor should be a matrix with shape [batch_size, " + "key_seq_len, hidden_size]."); + TT_ENFORCE_EQ(value_tensor.n_dim(), 3, + "The value_tensor should be a matrix with shape [batch_size, " + "key_seq_len, hidden_size]."); + TT_ENFORCE_EQ(query_tensor.n_dim(), 3, + "The query_tensors should be a matrix with shape [batch_size, " + "query_seq_len, hidden_size]."); + TT_ENFORCE_EQ( + key_tensor.shape(0), value_tensor.shape(0), + "The key_tensor and value_tensor should have the same hidden_size"); + + EnforceShapeAndType(); + auto batch_size = query_tensor.shape(0); + auto query_seq_length = + query_tensor.shape(1); // query_seq_length = from_seq_Len + + int64_t key_seq_length; + if (attn_type == "context") { + key_seq_length = key_tensor.shape(1); + } else if (attn_type == "self") { + key_seq_length = query_seq_length; + } else { + TT_THROW("attn_type should be context or self."); + } + + auto hidden_size = query_tensor.shape(2); + auto size_per_head = hidden_size / num_attention_heads_; + auto devtype = query_tensor.device_type(); + auto devid = query_tensor.device_id(); + + // TODO we should caching allocate intermediate tensor. + core::Tensor *q_ptr{nullptr}, *k_ptr{nullptr}, *v_ptr{nullptr}; + core::Tensor q_out1(nullptr); + core::Tensor v_out1(nullptr); + core::Tensor k_out1(nullptr); + core::Tensor q_out2(nullptr); + core::Tensor v_out2(nullptr); + core::Tensor k_out2(nullptr); + core::Tensor qkv_out1(nullptr); + core::Tensor qkv_out2(nullptr); + + bool layer_cache_not_none = layer_cache.size() > 0 ? 
true : false; + bool memory_keys_not_none = false, memory_values_not_none = false, + self_keys_not_none = false, self_values_not_none = false; + if (layer_cache_not_none) { + for (auto it = layer_cache.begin(); it != layer_cache.end(); ++it) { + if (it->first == "memory_keys" && !it->second->is_null()) { + memory_keys_not_none = true; + } + if (it->first == "memory_values" && !it->second->is_null()) { + memory_values_not_none = true; + } + if (it->first == "self_keys" && !it->second->is_null()) { + self_keys_not_none = true; + } + if (it->first == "self_values" && !it->second->is_null()) { + self_values_not_none = true; + } + } + } + bool memory_not_none = memory_values_not_none && memory_keys_not_none; + if (attn_type == "context") { + TT_ENFORCE_EQ(kernels::common::is_same_device_ctx( + query_tensor.device_ctx(), value_tensor.device_ctx()), + true, + "The query_tensor and value_tensor should have the same " + "device type and device id."); + TT_ENFORCE_EQ(kernels::common::is_same_device_ctx(query_tensor.device_ctx(), + key_tensor.device_ctx()), + true, + "The query_tensor and key_tensor should have the same " + "device type and device id."); + + q_out1.Reshape({batch_size, query_seq_length, hidden_size}, devtype, + devid, "context/gemm0/q_out1/Reshape"); + if (pre_layernorm) { + q_out2.Reshape({batch_size, query_seq_length, hidden_size}, + devtype, devid, "context/gemm0/q_out2/Reshape"); + core::Copy(query_tensor, q_out2, + "context/gemm0/prelayernorm/Copy"); + kernels::LayerNorm( + layernorm_gamma_, layernorm_beta_, &q_out2, 1e-6, + "context/gemm0/prelayernorm"); // q_out2 here is used as + // layernormed_query TODO(jiaruifang) + // 1e-6 should not be hard-coded + kernels::MatMul(q_out2, false, q_weight_, is_trans_weight, 1.0, &q_out1, + 0.0, "context/gemm0"); + } else { + kernels::MatMul(query_tensor, false, q_weight_, is_trans_weight, 1.0, + &q_out1, 0.0, "context/gemm0"); + } + q_out1.Reshape( + {batch_size, query_seq_length, num_attention_heads_, size_per_head}, + devtype, devid, "context/AddBiasTransposeForScore/q_out1/Reshape"); + q_out2.Reshape( + {batch_size, num_attention_heads_, query_seq_length, size_per_head}, + devtype, devid, "context/AddBiasTransposeForScore/q_out2/Reshape"); + kernels::AddBiasTransposeForScore(q_out1, q_bias_, &q_out2, + "context/AddBiasTransposeForScore"); + q_ptr = &q_out2; // point to static memory space + if (memory_not_none) { + v_ptr = layer_cache["memory_values"]; + k_ptr = layer_cache["memory_keys"]; + } else { + v_out1.Reshape({batch_size, key_seq_length, hidden_size}, devtype, + devid, "context/gemm1/v_out1/Reshape"); + k_out1.Reshape({batch_size, key_seq_length, hidden_size}, devtype, + devid, "context/gemm2/k_out1/Reshape"); + + kernels::MatMul(key_tensor, false, k_weight_, is_trans_weight, 1.0, + &k_out1, 0.0, "context/gemm1"); + kernels::MatMul(value_tensor, false, v_weight_, is_trans_weight, 1.0, + &v_out1, 0.0, "context/gemm2"); + v_out1.Reshape( + {batch_size, key_seq_length, num_attention_heads_, size_per_head}, + devtype, devid, "context/gemm1/v_out1/Reshape"); + k_out1.Reshape( + {batch_size, key_seq_length, num_attention_heads_, size_per_head}, + devtype, devid, "context/gemm2/k_out1/Reshape"); + + if (layer_cache_not_none) { + layer_cache["memory_keys"]->Reshape( + {batch_size, num_attention_heads_, key_seq_length, size_per_head}, + devtype, devid, "context/keys/AddBiasTransposeForScore/Reshape"); + layer_cache["memory_values"]->Reshape( + {batch_size, num_attention_heads_, key_seq_length, size_per_head}, + devtype, devid, 
"context/values/AddBiasTransposeForScore/reshape"); + kernels::AddBiasTransposeForScore( + v_out1, v_bias_, layer_cache["memory_values"], + "context/values/AddBiasTransposeForScore"); + kernels::AddBiasTransposeForScore( + k_out1, k_bias_, layer_cache["memory_keys"], + "context/keys/AddBiasTransposeForScore"); + v_ptr = layer_cache["memory_values"]; + k_ptr = layer_cache["memory_keys"]; + } else { + v_out2.Reshape( + {batch_size, num_attention_heads_, key_seq_length, size_per_head}, + devtype, devid, "context/values/AddBiasTransposeForScore/Reshape"); + k_out2.Reshape( + {batch_size, num_attention_heads_, key_seq_length, size_per_head}, + devtype, devid, "context/keys/AddBiasTransposeForScore/Reshape"); + kernels::AddBiasTransposeForScore( + v_out1, v_bias_, &v_out2, + "context/values/AddBiasTransposeForScore"); + kernels::AddBiasTransposeForScore( + k_out1, k_bias_, &k_out2, "context/keys/AddBiasTransposeForScore"); + v_ptr = &v_out2; + k_ptr = &k_out2; + } + } // else + } else if (attn_type == "self") { + qkv_out1.Reshape({3, batch_size, query_seq_length, hidden_size}, + devtype, devid, "self/qkv_out1/Reshape"); + if (pre_layernorm) { + core::Tensor layernormed_query(nullptr); + layernormed_query.Reshape( + {batch_size, query_seq_length, hidden_size}, devtype, devid, + "self/layernorm/Reshape"); + core::Copy(query_tensor, layernormed_query, "self/layernorm/Copy"); + kernels::LayerNorm(layernorm_gamma_, layernorm_beta_, + &layernormed_query, 1e-6); + kernels::MatMul(layernormed_query, false, qkv_weight_, is_trans_weight, + 1.0, &qkv_out1, 0.0, "self/gemm012_fused"); + } else { + kernels::MatMul(query_tensor, false, qkv_weight_, is_trans_weight, 1.0, + &qkv_out1, 0.0, "self/gemm012_fused"); + } + qkv_out2.Reshape( + {3, batch_size, num_attention_heads_, query_seq_length, size_per_head}, + devtype, devid, "self/qkv_out2/Reshape"); + kernels::SplitAddBiasTransposeForScore( + &qkv_out2, qkv_out1, qkv_bias_, "self/SplitAddBiasTransposeForScore"); + q_ptr = + new core::Tensor(qkv_out2[0]); // copy temporary tensor to heap space. + if (self_keys_not_none) { + kernels::Concat(*layer_cache["self_keys"], qkv_out2[1], 2, &k_out2, + "self/keys/Concat"); + k_ptr = &k_out2; + } else { + k_ptr = new core::Tensor(qkv_out2[1]); + } + if (self_values_not_none) { + kernels::Concat(*layer_cache["self_values"], qkv_out2[2], 2, + &v_out2, "self/values/Concat"); + v_ptr = &v_out2; + } else { + v_ptr = new core::Tensor(qkv_out2[2]); + } + if (layer_cache_not_none) { + layer_cache["self_keys"]->Reshape( + {batch_size, num_attention_heads_, k_ptr->shape(2), size_per_head}, + devtype, devid, "self/self_key/Reshape"); + layer_cache["self_values"]->Reshape( + {batch_size, num_attention_heads_, v_ptr->shape(2), size_per_head}, + devtype, devid, "self/self_value/Reshape"); + + core::Copy(*k_ptr, *layer_cache["self_keys"], + "self/self_key/Copy"); + core::Copy(*v_ptr, *layer_cache["self_values"], + "self/self_value/Copy"); + } + } else { + TT_THROW("%s is not support in MultiHeadedAttention\n", attn_type); + } // if (attn_type == "context") + // 2) Calculate and scale scores. + key_seq_length = k_ptr->shape( + 2); // update for self type attn, since it will concat with cache. 
+ att_score->Reshape( + {batch_size, num_attention_heads_, query_seq_length, + key_seq_length}, // query_seq_length = from_seq_Len + devtype, devid, "batch_gemm3/Reshape"); + + const float scaler = 1.0f / std::sqrt(static_cast(size_per_head)); + kernels::BatchMatMul(*q_ptr, false, *k_ptr, true, scaler, att_score, 0.0, + "batch_gemm3"); //(B, num_head, q_len, k_len) + // mask = mask.unsqueeze(1) # [B, 1, 1, T_values] + // scores = scores.masked_fill(mask, -1e18) + // attn = self.softmax(scores).to(query.dtype) + kernels::ApplyMaskAndSoftmax( + att_score, + attention_mask, //(B, q_len, k_len) or (B, 1, k_len) + 1.0, "ApplyMaskAndSoftmax"); + + // context_original = torch.matmul(drop_attn, value) + core::Tensor context_layer(nullptr); + context_layer.Reshape( + {batch_size, num_attention_heads_, query_seq_length, size_per_head}, + devtype, devid, "ApplyMaskAndSoftmax/Reshape"); + + kernels::BatchMatMul(*att_score, false, *v_ptr, false, 1.0, &context_layer, + 0.0, "batch_gemm4"); + // context = unshape(context_original) + core::Tensor self_attr_out(nullptr); + + self_attr_out.Reshape( + {batch_size, query_seq_length, num_attention_heads_ * size_per_head}, + devtype, devid, "batch_gemm4/Reshape"); + kernels::TransposeForScore(&self_attr_out, context_layer, + "TransposeForScore"); + // output = self.final_linear(context) + output->Reshape({batch_size, query_seq_length, hidden_size}, devtype, + devid, "gemm5/Reshape"); + + kernels::MatMul(self_attr_out, false, dense_weight_, is_trans_weight, 1.0, + output, 0.0, "gemm5"); + + if (false == post_add_input) { + if (false == post_layernorm) { + //+bias + kernels::AddBias(dense_bias_, output, "AddBias"); + } else { + //+bias+layernorm + kernels::AddBiasLayerNorm(query_tensor, dense_bias_, + layernorm_gamma_, // gemma + layernorm_beta_, output, 1e-12, + "AddBiasLayerNorm"); + } + } else { + //+input + bias + kernels::AddInputBias(*output, query_tensor, dense_bias_, output); + } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile("MultiHeadedAttention_" + attn_type, devtype); +#endif +} + +void MultiHeadedAttention::EnforceShapeAndType() const { + if (loguru::current_verbosity_cutoff() >= 3) { + std::ostringstream os; + os << ">>>>>>>>>>>> qkv_weight_ <<<<<<<<<<<<" << std::endl; + q_weight_.Print(os); + os << ">>>>>>>>>>>> qkv_bias_ <<<<<<<<<<<<" << std::endl; + q_bias_.Print(os); + os << ">>>>>>>>>>>> dense_weight_ <<<<<<<<<<<<" << std::endl; + dense_weight_.Print(os); + os << ">>>>>>>>>>>> dense_bias_ <<<<<<<<<<<<" << std::endl; + dense_bias_.Print(os); + LOG_S(3) << os.str(); + } +} + +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/multi_headed_attention.h b/turbo_transformers/layers/multi_headed_attention.h new file mode 100644 index 00000000..1e42e52d --- /dev/null +++ b/turbo_transformers/layers/multi_headed_attention.h @@ -0,0 +1,106 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. 
+// See the AUTHORS file for names of contributors. + +#pragma once +#include +#include +#include + +#include +#include "turbo_transformers/core/tensor.h" + +namespace turbo_transformers { +namespace layers { + +class MultiHeadedAttention { + public: + MultiHeadedAttention(core::Tensor k_weight, core::Tensor k_bias, + core::Tensor v_weight, core::Tensor v_bias, + core::Tensor q_weight, core::Tensor q_bias, + core::Tensor dense_weight, core::Tensor dense_bias, + core::Tensor qkv_weight, core::Tensor qkv_bias, + int64_t num_attention_heads) + : k_weight_(std::move(k_weight)), //(768, 768) + k_bias_(std::move(k_bias)), + v_weight_(std::move(v_weight)), //(768, 768) + v_bias_(std::move(v_bias)), + q_weight_(std::move(q_weight)), //(768, 768) + q_bias_(std::move(q_bias)), + dense_weight_(std::move(dense_weight)), + dense_bias_(std::move(dense_bias)), + qkv_weight_(std::move(qkv_weight)), + qkv_bias_(std::move(qkv_bias)), + layernorm_gamma_(nullptr), + layernorm_beta_(nullptr), + num_attention_heads_(num_attention_heads) { + EnforceShapeAndType(); + } + + MultiHeadedAttention(core::Tensor k_weight, core::Tensor k_bias, + core::Tensor v_weight, core::Tensor v_bias, + core::Tensor q_weight, core::Tensor q_bias, + core::Tensor dense_weight, core::Tensor dense_bias, + core::Tensor qkv_weight, core::Tensor qkv_bias, + core::Tensor layernorm_gamma, + core::Tensor layernorm_beta, + int64_t num_attention_heads) + : k_weight_(std::move(k_weight)), //(768, 768) + k_bias_(std::move(k_bias)), + v_weight_(std::move(v_weight)), //(768, 768) + v_bias_(std::move(v_bias)), + q_weight_(std::move(q_weight)), //(768, 768) + q_bias_(std::move(q_bias)), + dense_weight_(std::move(dense_weight)), + dense_bias_(std::move(dense_bias)), + qkv_weight_(std::move(qkv_weight)), + qkv_bias_(std::move(qkv_bias)), + layernorm_gamma_(std::move(layernorm_gamma)), + layernorm_beta_(std::move(layernorm_beta)), + num_attention_heads_(num_attention_heads) { + EnforceShapeAndType(); + } + void EnforceShapeAndType() const; + + void operator()(const core::Tensor& key_tensor, + const core::Tensor& value_tensor, + const core::Tensor& query_tensor, + const core::Tensor& attention_mask, + const std::string& attn_type, core::Tensor* output, + core::Tensor* att_score, + std::unordered_map layer_cache, + bool pre_layernorm = false, bool post_layernorm = false, + bool post_add_input = false, + bool is_trans_weight = false) const; + + private: + core::Tensor k_weight_; + core::Tensor k_bias_; + core::Tensor v_weight_; + core::Tensor v_bias_; + core::Tensor q_weight_; + core::Tensor q_bias_; + core::Tensor dense_weight_; + core::Tensor dense_bias_; + + core::Tensor qkv_weight_; // store fused qkv weight/bias for "self" type + // attention computation. + core::Tensor qkv_bias_; + + core::Tensor layernorm_gamma_; + core::Tensor layernorm_beta_; + + int64_t num_attention_heads_; +}; + +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/positionwise_ffn.cpp b/turbo_transformers/layers/positionwise_ffn.cpp new file mode 100644 index 00000000..fbec26bf --- /dev/null +++ b/turbo_transformers/layers/positionwise_ffn.cpp @@ -0,0 +1,80 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. 
You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. + +#include "turbo_transformers/layers/positionwise_ffn.h" + +#include + +#include "turbo_transformers/core/memory.h" +#include "turbo_transformers/layers/kernels/activation.h" +#include "turbo_transformers/layers/kernels/common.h" +#include "turbo_transformers/layers/kernels/layer_norm.h" +#include "turbo_transformers/layers/kernels/mat_mul.h" +#include "turbo_transformers/layers/kernels/utils.h" +#ifdef WITH_PERFTOOLS +#include "turbo_transformers/core/profiler.h" +#endif + +namespace turbo_transformers { +namespace layers { + +void PositionwiseFeedForward::operator()(const core::Tensor& input_tensor, + core::Tensor* output_tensor, + bool is_trans_weight) const { + auto d_ff = + is_trans_weight ? dense_weight_1_.shape(0) : dense_weight_1_.shape(1); + + auto model_dim_weight = + is_trans_weight ? dense_weight_1_.shape(1) : dense_weight_1_.shape(0); + auto model_dim = input_tensor.shape(2); + + TT_ENFORCE_EQ( + model_dim_weight, model_dim, + "dense weight and input tensor should have the same model_dim."); + + auto devType = input_tensor.device_type(); + auto devId = input_tensor.device_id(); + + // input tensor size (batch_size, input_len, model_dim) + auto batch_size = input_tensor.shape(0); + auto input_len = input_tensor.shape(1); + // allocate memory for temp data + core::Tensor input_tensor_copy(nullptr); + input_tensor_copy.Reshape({batch_size, input_len, model_dim}, devType, + devId); + core::Tensor temp_tensor(nullptr); + temp_tensor.Reshape({batch_size * input_len, d_ff}, devType, devId); + + // start computation + core::Copy(input_tensor, input_tensor_copy, "FFN/AddInputBias"); + + output_tensor->Reshape({batch_size, input_len, model_dim}, devType, + devId, "FFN/Reshape"); + kernels::LayerNorm(layer_norm_weight_, layer_norm_bias_, + &input_tensor_copy, 1e-12, "FFN/LayerNorm"); + kernels::MatMul(input_tensor_copy, false, dense_weight_1_, is_trans_weight, + 1.0, // input (b*seq, model) X dense_weight_1_ (model_dim, + // d_ff) -> temp_tensor (B*seq, d_ff) + &temp_tensor, 0.0, "FFN/gemm0"); + kernels::AddBiasAct( + dense_bias_1_, &temp_tensor, "FFN/AddBiasAct"); + kernels::MatMul(temp_tensor, false, dense_weight_2_, is_trans_weight, 1.0, + &input_tensor_copy, 0.0, "FFN/gemm1"); + kernels::AddInputBias(input_tensor, input_tensor_copy, dense_bias_2_, + output_tensor, "FFN/AddInputBias"); +} + +void PositionwiseFeedForward::EnforceShapeAndType() const {} + +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/positionwise_ffn.h b/turbo_transformers/layers/positionwise_ffn.h new file mode 100644 index 00000000..9e091872 --- /dev/null +++ b/turbo_transformers/layers/positionwise_ffn.h @@ -0,0 +1,55 @@ +// Copyright (C) 2020 THL A29 Limited, a Tencent company. +// All rights reserved. +// Licensed under the BSD 3-Clause License (the "License"); you may +// not use this file except in compliance with the License. 
You may +// obtain a copy of the License at +// https://opensource.org/licenses/BSD-3-Clause +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" basis, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. +// See the AUTHORS file for names of contributors. + +#pragma once +#include +#include +#include "turbo_transformers/core/tensor.h" + +namespace turbo_transformers { +namespace layers { + +class PositionwiseFeedForward { + public: + PositionwiseFeedForward(core::Tensor dense_weight_1, + core::Tensor dense_bias_1, + core::Tensor dense_weight_2, + core::Tensor dense_bias_2, + core::Tensor layer_norm_weight, + core::Tensor layer_norm_bias) + : dense_weight_1_(std::move(dense_weight_1)), + dense_bias_1_(std::move(dense_bias_1)), + dense_weight_2_(std::move(dense_weight_2)), + dense_bias_2_(std::move(dense_bias_2)), + layer_norm_weight_(std::move(layer_norm_weight)), + layer_norm_bias_(std::move(layer_norm_bias)) { + EnforceShapeAndType(); + } + void EnforceShapeAndType() const; + + // according to profiling results on Intel 61xx, is_trans_weight = true is + // faster + void operator()(const core::Tensor &input_tensor, core::Tensor *output, + bool is_trans_weight = true) const; + + private: + core::Tensor dense_weight_1_; + core::Tensor dense_bias_1_; + core::Tensor dense_weight_2_; + core::Tensor dense_bias_2_; + core::Tensor layer_norm_weight_; + core::Tensor layer_norm_bias_; +}; + +} // namespace layers +} // namespace turbo_transformers diff --git a/turbo_transformers/layers/types.h b/turbo_transformers/layers/types.h index 182e275e..f957678b 100644 --- a/turbo_transformers/layers/types.h +++ b/turbo_transformers/layers/types.h @@ -17,7 +17,7 @@ namespace layers { namespace types { enum class ReduceType { kMax = 0, kSum }; -enum class ActivationType { Gelu = 0, Tanh }; +enum class ActivationType { Gelu = 0, Tanh = 1, Relu = 2 }; enum class PoolType { kMax = 0, kMean, kFirst, kLast }; } // namespace types } // namespace layers diff --git a/turbo_transformers/python/pybind.cpp b/turbo_transformers/python/pybind.cpp index 09bf1ee0..48bd651a 100644 --- a/turbo_transformers/python/pybind.cpp +++ b/turbo_transformers/python/pybind.cpp @@ -11,9 +11,11 @@ // permissions and limitations under the License. // See the AUTHORS file for names of contributors. 
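Taken together, the positionwise FFN kernel above computes a pre-layernorm feed-forward block with a residual connection. A rough PyTorch equivalent of the kernel sequence (a sketch only; it assumes the ReLU activation used by ONMT's PositionwiseFeedForward and an (in_features, out_features) weight layout):

```python
import torch.nn.functional as F

def positionwise_ffn(x, w1, b1, w2, b2, ln_weight, ln_bias):
    # x: (batch, seq_len, model_dim); w1: (model_dim, d_ff); w2: (d_ff, model_dim)
    y = F.layer_norm(x, x.shape[-1:], ln_weight, ln_bias, eps=1e-12)  # FFN/LayerNorm
    h = F.relu(y @ w1 + b1)                                           # FFN/gemm0 + AddBiasAct
    return x + h @ w2 + b2                                            # FFN/gemm1 + AddInputBias (residual)
```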
+#include #include "absl/memory/memory.h" #include "loguru.hpp" #include "pybind11/pybind11.h" + #include "turbo_transformers/core/blas.h" #include "turbo_transformers/core/config.h" #include "turbo_transformers/core/profiler.h" @@ -23,6 +25,8 @@ #include "turbo_transformers/layers/bert_intermediate.h" #include "turbo_transformers/layers/bert_output.h" #include "turbo_transformers/layers/bert_pooler.h" +#include "turbo_transformers/layers/multi_headed_attention.h" +#include "turbo_transformers/layers/positionwise_ffn.h" #include "turbo_transformers/layers/prepare_bert_masks.h" #include "turbo_transformers/layers/sequence_pool.h" @@ -65,8 +69,8 @@ PYBIND11_MODULE(turbo_transformers_cxx, m) { m.def("set_stderr_verbose_level", [](int v) { loguru::g_stderr_verbosity = v; }); - m.def("enable_gperf", &core::EnableGperf); - m.def("disable_gperf", &core::DisableGperf); + m.def("enable_perf", &core::EnableGperf); + m.def("disable_perf", &core::DisableGperf); m.def("set_num_threads", &core::SetNumThreads); py::class_(m, "Tensor") @@ -112,6 +116,41 @@ PYBIND11_MODULE(turbo_transformers_cxx, m) { })) .def("__call__", &layers::BertAttention::operator()); + py::class_(m, "MultiHeadedAttention") + .def(py::init( + [](core::Tensor &key_weight, core::Tensor &key_bias, + core::Tensor &value_weight, core::Tensor &value_bias, + core::Tensor &query_weight, core::Tensor &query_bias, + core::Tensor &dense_weight, core::Tensor &dense_bias, + core::Tensor &qkv_weight, core::Tensor &qkv_bias, + int num_attention_heads) -> layers::MultiHeadedAttention * { + return new layers::MultiHeadedAttention( + std::move(key_weight), std::move(key_bias), + std::move(value_weight), std::move(value_bias), + std::move(query_weight), std::move(query_bias), + std::move(dense_weight), std::move(dense_bias), + std::move(qkv_weight), std::move(qkv_bias), + num_attention_heads); + })) + .def(py::init( + [](core::Tensor &key_weight, core::Tensor &key_bias, + core::Tensor &value_weight, core::Tensor &value_bias, + core::Tensor &query_weight, core::Tensor &query_bias, + core::Tensor &dense_weight, core::Tensor &dense_bias, + core::Tensor &qkv_weight, core::Tensor &qkv_bias, + core::Tensor &layernorm_gamma, core::Tensor &layernorm_beta, + int num_attention_heads) -> layers::MultiHeadedAttention * { + return new layers::MultiHeadedAttention( + std::move(key_weight), std::move(key_bias), + std::move(value_weight), std::move(value_bias), + std::move(query_weight), std::move(query_bias), + std::move(dense_weight), std::move(dense_bias), + std::move(qkv_weight), std::move(qkv_bias), + std::move(layernorm_gamma), std::move(layernorm_beta), + num_attention_heads); + })) + .def("__call__", &layers::MultiHeadedAttention::operator()); + py::class_(m, "BertIntermediate") .def(py::init([](core::Tensor &dense_weight, core::Tensor &dense_bias) -> layers::BertIntermediate * { @@ -147,6 +186,19 @@ PYBIND11_MODULE(turbo_transformers_cxx, m) { py::class_(m, "PrepareBertMasks") .def(py::init()) .def("__call__", &layers::PrepareBertMasks::operator()); + + py::class_(m, "PositionwiseFeedForward") + .def(py::init([](core::Tensor &dense_weight_1, core::Tensor &dense_bias_1, + core::Tensor &dense_weight_2, core::Tensor &dense_bias_2, + core::Tensor &layer_norm_weight, + core::Tensor &layer_norm_bias) + -> layers::PositionwiseFeedForward * { + return new layers::PositionwiseFeedForward( + std::move(dense_weight_1), std::move(dense_bias_1), + std::move(dense_weight_2), std::move(dense_bias_2), + std::move(layer_norm_weight), std::move(layer_norm_bias)); + })) + 
.def("__call__", &layers::PositionwiseFeedForward::operator()); } } // namespace python diff --git a/turbo_transformers/python/tests/bert_attention_test.py b/turbo_transformers/python/tests/bert_attention_test.py index 6a9a1ce2..64f4966b 100644 --- a/turbo_transformers/python/tests/bert_attention_test.py +++ b/turbo_transformers/python/tests/bert_attention_test.py @@ -31,7 +31,8 @@ def init_data(self, use_cuda): test_device = torch.device('cuda:0') if use_cuda else \ torch.device('cpu:0') if not use_cuda: - torch.set_num_threads(1) + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) torch.set_grad_enabled(False) cfg = BertConfig(attention_probs_dropout_prob=0.0, @@ -44,6 +45,10 @@ def init_data(self, use_cuda): # Get FT Attention turbo_attention = turbo_transformers.BertAttention.from_torch( torch_attention) + + turbo_decoder_attention = turbo_transformers.MultiHeadedAttention.from_torch( + torch_attention, is_trans_weight=False) + hidden_size = cfg.hidden_size input_tensor = torch.rand(size=(batch_size, seq_length, hidden_size), @@ -54,10 +59,10 @@ def init_data(self, use_cuda): device=test_device) attention_mask = attention_mask[:, None, None, :] attention_mask = (1.0 - attention_mask) * -10000.0 - return torch_attention, turbo_attention, input_tensor, attention_mask + return torch_attention, turbo_attention, turbo_decoder_attention, input_tensor, attention_mask - def check_torch_and_turbo(self, use_cuda, num_iter=2): - torch_attention, turbo_attention, input_tensor, attention_mask = \ + def check_torch_and_turbo(self, use_cuda, num_iter=1): + torch_attention, turbo_attention, turbo_decoder_attention, input_tensor, attention_mask = \ self.init_data(use_cuda) device = "GPU" if use_cuda else "CPU" torch_model = lambda: torch_attention(input_tensor, attention_mask) @@ -67,9 +72,9 @@ def check_torch_and_turbo(self, use_cuda, num_iter=2): f"BertAttention \"({batch_size},{seq_length:03})\" ", f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}") - turob_model = lambda: turbo_attention(input_tensor, attention_mask) - turbo_self_attention_result, turbo_qps, turbo_time_consume = \ - test_helper.run_model(turob_model, use_cuda, + turbo_model = lambda: turbo_attention(input_tensor, attention_mask) + turbo_attention_result, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_model, use_cuda, num_iter) print( f"BertAttention \"({batch_size},{seq_length:03})\" ", @@ -79,18 +84,43 @@ def check_torch_and_turbo(self, use_cuda, num_iter=2): self.assertTrue( torch.max( torch.abs(torch_attention_result[0] - - turbo_self_attention_result)) < 1e-3 - if use_cuda else 1e-4) + turbo_attention_result[0])) < ( + 1e-3 if use_cuda else 1e-4)) + + turbo_multiheaded_model = lambda: turbo_decoder_attention( + input_tensor, + input_tensor, + input_tensor, + attention_mask, + layer_cache=None, + attn_type="self", + pre_layernorm=False, + post_layernorm=True, + post_add_input=False, + is_trans_weight=False) + turbo_decoder_attn_result, turbo_decoder_qps, turbo_decoder_time_consume = \ + test_helper.run_model(turbo_multiheaded_model, use_cuda, + num_iter, use_profile=False) + print( + f"MultiHeadedAttention \"({batch_size},{seq_length:03})\" ", + f" {device} Turbo QPS, {turbo_decoder_qps}, time, {turbo_decoder_time_consume}" + ) + self.assertTrue( + torch.max( + torch.abs(torch_attention_result[0] - + turbo_decoder_attn_result[0])) < ( + 1e-3 if use_cuda else 1e-4)) + with open(fname, "a") as fh: fh.write( f"\"({batch_size},{seq_length:03})\", {torch_qps}, {turbo_qps}\n" ) def 
test_bert_attention(self): - self.check_torch_and_turbo(use_cuda=False) + self.check_torch_and_turbo(use_cuda=False, num_iter=1) if torch.cuda.is_available() and \ turbo_transformers.config.is_compiled_with_cuda(): - self.check_torch_and_turbo(use_cuda=True) + self.check_torch_and_turbo(use_cuda=True, num_iter=1) globals()[f"TestBertAtt{batch_size}_{seq_length:3}"] = TestBertAttention @@ -98,7 +128,7 @@ def test_bert_attention(self): with open(fname, "w") as fh: fh.write(", torch, turbo_transformers\n") for batch_size in [1, 2]: - for seq_length in [10, 16, 20, 24, 40, 48, 60, 64, 80, 100, 120, 128]: + for seq_length in [10, 20, 40, 60, 80, 100]: create_test(batch_size, seq_length) if __name__ == '__main__': diff --git a/turbo_transformers/python/tests/bert_layer_test.py b/turbo_transformers/python/tests/bert_layer_test.py index 7d2a2316..73844dbd 100644 --- a/turbo_transformers/python/tests/bert_layer_test.py +++ b/turbo_transformers/python/tests/bert_layer_test.py @@ -83,7 +83,7 @@ def check_torch_and_turbo(self, use_cuda): self.assertTrue( torch.max( torch.abs(torch_bert_layer_result[0] - - turbo_bert_layer_result)) < tolerate_error) + turbo_bert_layer_result[0])) < tolerate_error) with open(fname, "a") as fh: fh.write( f"\"({batch_size},{seq_length:03})\", {torch_qps}, {turbo_qps}\n" diff --git a/turbo_transformers/python/tests/bert_output_test.py b/turbo_transformers/python/tests/bert_output_test.py index d9d2c5db..f963c7d4 100644 --- a/turbo_transformers/python/tests/bert_output_test.py +++ b/turbo_transformers/python/tests/bert_output_test.py @@ -55,7 +55,6 @@ def init_data(self, use_cuda) -> None: def check_torch_and_turbo(self, use_cuda): self.init_data(use_cuda) - sio = io.StringIO() num_iter = 2 device = "GPU" if use_cuda else "CPU" @@ -63,28 +62,21 @@ def check_torch_and_turbo(self, use_cuda): self.attention_output) torch_result, torch_qps, torch_time = \ test_helper.run_model(torch_model, use_cuda, num_iter) - print(f'Bert Output Plain PyTorch({device}) QPS {torch_qps}', - file=sio) + print(f'Bert Output Plain PyTorch({device}) QPS {torch_qps}') turbo_model = lambda: self.turbo_bertout(self.intermediate_output, self.attention_output) turbo_result, turbo_qps, turbo_time = \ test_helper.run_model(turbo_model, use_cuda, num_iter) print( - f'Bert Output Plain TurboTransformer({device}) QPS {turbo_qps}', - file=sio) + f'Bert Output Plain TurboTransformer({device}) QPS {turbo_qps}' + ) # cuda version precision is lower due to tensor-core self.assertTrue( torch.max(torch.abs(torch_result - turbo_result)) < 1e-2 if use_cuda else 1e-4) - sio.seek(0) - with open(f"gpu_bert_output_qps_{batch_size}_{seq_length:03}.txt", - "w") as of: - for line in sio: - print(line.strip(), file=of) - def test_bertout(self): self.check_torch_and_turbo(use_cuda=False) if torch.cuda.is_available() and \ diff --git a/turbo_transformers/python/tests/decoder_multi_headed_attention_test.py b/turbo_transformers/python/tests/decoder_multi_headed_attention_test.py new file mode 100644 index 00000000..e0a2c091 --- /dev/null +++ b/turbo_transformers/python/tests/decoder_multi_headed_attention_test.py @@ -0,0 +1,339 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. 
You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +import turbo_transformers + +import unittest +import sys +import torch +import os + +from onmt.modules.multi_headed_attn import MultiHeadedAttention +# from my_multi_headed_attn import MultiHeadedAttention + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "tt_decoder_multi_headed_attention.txt" + + +def create_test(batch_size, + key_seq_len, + query_seq_len, + attn_type, + pre_layernorm, + post_add_input, + with_quantize_dynamic=False, + set_layer_cache=False): + class TestMultiHeadedAttention(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + torch.set_grad_enabled(False) + self.head_count = 16 + self.model_dim = 1024 #self.model_dim should % self.head_count = 0 + self.size_per_head = int(self.model_dim / self.head_count) + + onmt_multi_headed_attention = MultiHeadedAttention( + self.head_count, self.model_dim) + onmt_multi_headed_attention.eval() + torch_layernorm = torch.nn.LayerNorm(self.model_dim, eps=1e-6) + torch_layernorm.eval() + + if use_cuda: + onmt_multi_headed_attention.to(self.test_device) + torch_layernorm.to(self.test_device) + + K = torch.rand( + size=( + batch_size, + key_seq_len, #from_seq + self.model_dim), + dtype=torch.float32, + device=self.test_device) + V = torch.rand(size=(batch_size, key_seq_len, self.model_dim), + dtype=torch.float32, + device=self.test_device) + Q = torch.rand( + size=( + batch_size, + query_seq_len, #to_seq + self.model_dim), + dtype=torch.float32, + device=self.test_device) + + turbo_attn_trans = turbo_transformers.MultiHeadedAttention.from_onmt( + onmt_multi_headed_attention, + torch_layernorm, + is_trans_weight=True) + turbo_attn_notrans = turbo_transformers.MultiHeadedAttention.from_onmt( + onmt_multi_headed_attention, + torch_layernorm, + is_trans_weight=False) + + if with_quantize_dynamic and not use_cuda: + self.q_onmt_multi_headed_attention = torch.quantization.quantize_dynamic( + onmt_multi_headed_attention) + return onmt_multi_headed_attention, torch_layernorm, turbo_attn_trans, turbo_attn_notrans, Q, K, V + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + onmt_multi_headed_attention, torch_layernorm, turbo_attn_trans, turbo_attn_notrans, Q, K, V = \ + self.init_data(use_cuda) + device = "GPU" if use_cuda else "CPU" + info = f"\"({device}, {set_layer_cache}, {pre_layernorm}, {post_add_input}, {attn_type}, {batch_size}, {key_seq_len:03}, {query_seq_len:03})\"" + + if attn_type == "context": + attention_mask = torch.zeros((batch_size, 1, key_seq_len), + dtype=torch.bool, + device=self.test_device) + elif attn_type == "self": + attention_mask = None + # torch.zeros( + # (batch_size, query_seq_len, key_seq_len), + # dtype=torch.bool, + # device=self.test_device) + else: + raise "attn type is not supported" + + # set layer_cache + if set_layer_cache: + memory_keys = torch.rand(size=(batch_size, self.head_count, + key_seq_len, + self.size_per_head), + 
dtype=torch.float32, + device=self.test_device) + memory_values = torch.rand(size=(batch_size, self.head_count, + key_seq_len, + self.size_per_head), + dtype=torch.float32, + device=self.test_device) + self_keys = torch.rand(size=(batch_size, self.head_count, + query_seq_len, + self.size_per_head), + dtype=torch.float32, + device=self.test_device) + self_values = torch.rand(size=(batch_size, self.head_count, + query_seq_len, + self.size_per_head), + dtype=torch.float32, + device=self.test_device) + print("self_keys size: ", self_keys.size()) + layer_cache_torch = { + "memory_keys": torch.clone(memory_keys), + "memory_values": torch.clone(memory_values), + "self_keys": torch.clone(self_keys), + "self_values": torch.clone(self_values) + } + else: + layer_cache_torch = { + "memory_keys": None, + "memory_values": None, + "self_keys": None, + "self_values": None + } + + onmt_model = lambda: onmt_multi_headed_attention( + K, + V, + torch.clone(torch_layernorm(Q)) if pre_layernorm else Q, + mask=attention_mask, + layer_cache=layer_cache_torch, + attn_type=attn_type) + + onmt_multi_headed_attention_result, torch_qps, torch_time_consume = \ + test_helper.run_model(onmt_model, use_cuda, num_iter) # return output, attns + + onmt_attns = onmt_multi_headed_attention_result[1] + if post_add_input: + onmt_output = onmt_multi_headed_attention_result[0] + Q + else: + onmt_output = onmt_multi_headed_attention_result[0] + print( + f"Multi Headed Attention {info} ONMT, QPS,{torch_qps}, time, {torch_time_consume}" + ) + + if with_quantize_dynamic and not use_cuda: + q_onmt_model = lambda: self.q_onmt_multi_headed_attention( + K, + V, + torch.clone(torch_layernorm(Q)) if pre_layernorm else Q, + mask=attention_mask, + layer_cache=layer_cache_torch, + attn_type=attn_type) + + q_onmt_multi_headed_attention_result, q_torch_qps, q_torch_time_consume = \ + test_helper.run_model(q_onmt_model, use_cuda, num_iter) # return output, attns + onmt_attns = q_onmt_multi_headed_attention_result[1] + if post_add_input: + onmt_output = q_onmt_multi_headed_attention_result[0] + Q + else: + onmt_output = q_onmt_multi_headed_attention_result[0] + + print( + f"Multi Headed Attention {info} Q-ONMT, QPS, {q_torch_qps}, time, {q_torch_time_consume}" + ) + + # benchmarking turbo with weight transposed + turbo_attention_mask = attention_mask.float( + ) * -1e18 if attention_mask is not None else None + + if set_layer_cache: + layer_cache_turbo = { + "memory_keys": torch.clone(memory_keys), + "memory_values": torch.clone(memory_values), + "self_keys": torch.clone(self_keys), + "self_values": torch.clone(self_values) + } + else: + layer_cache_turbo = { + "memory_keys": None, + "memory_values": None, + "self_keys": None, + "self_values": None + } + + turbo_model_trans = lambda: turbo_attn_trans( + K, + V, + Q, + turbo_attention_mask, + layer_cache=layer_cache_turbo, + attn_type=attn_type, + pre_layernorm=pre_layernorm, + post_add_input=post_add_input, + is_trans_weight=True) + + # with turbo_transformers.pref_guard("pref_test") as perf: + turbo_result, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_model_trans, use_cuda, + num_iter) + + turbo_output_trans, turbo_attns_trans = turbo_result + print( + f"Multi Headed Attention {info} Turbo Trans, QPS, {turbo_qps}, time, {turbo_time_consume}" + ) + self.assertTrue( + torch.max(torch.abs(onmt_output - turbo_output_trans)) < ( + 1e-3 if use_cuda else 1e-4)) + self.assertTrue( + torch.max(torch.abs(onmt_attns - turbo_attns_trans)) < ( + 1e-3 if use_cuda else 1e-4)) + + if 
layer_cache_torch is not None: + for k, v in layer_cache_torch.items(): + if v is not None: + self.assertTrue( + torch.max(torch.abs(layer_cache_turbo[k] - + v)) < 1e-3) + + # benchmarking turbo with weight not transposed + if set_layer_cache: + layer_cache_turbo = { + "memory_keys": torch.clone(memory_keys), + "memory_values": torch.clone(memory_values), + "self_keys": torch.clone(self_keys), + "self_values": torch.clone(self_values) + } + else: + layer_cache_turbo = { + "memory_keys": None, + "memory_values": None, + "self_keys": None, + "self_values": None + } + + turbo_model_notrans = lambda: turbo_attn_notrans( + K, + V, + Q, + turbo_attention_mask, + layer_cache=layer_cache_turbo, + attn_type=attn_type, + pre_layernorm=pre_layernorm, + post_add_input=post_add_input, + is_trans_weight=False) + + with turbo_transformers.pref_guard("pref_test") as perf: + turbo_result, turbo_qps, turbo_time_consume_notrans = \ + test_helper.run_model(turbo_model_notrans, use_cuda, + num_iter) + + turbo_output_notrans, turbo_attns_notrans = turbo_result + + print( + f"Multi Headed Attention {info} Turbo NoTrans, QPS,{turbo_qps}, time, {turbo_time_consume_notrans}" + ) + + self.assertTrue( + torch.max(torch.abs(onmt_output - turbo_output_notrans)) < ( + 1e-3 if use_cuda else 1e-4)) + self.assertTrue( + torch.max(torch.abs(onmt_attns - turbo_attns_notrans)) < ( + 1e-3 if use_cuda else 1e-4)) + + if with_quantize_dynamic and not use_cuda: + with open(fname, "a") as fh: + fh.write( + f"{info} {torch_qps}, {q_torch_qps}, {turbo_qps}\n") + else: + with open(fname, "a") as fh: + fh.write(f"{info} {torch_qps}, {turbo_qps}\n") + + def test_multi_headed_attention(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals( + )[f"TestMultiHeadedAttention{batch_size}_{key_seq_len:3}_{query_seq_len:3}_{attn_type}_{pre_layernorm}_{post_add_input}_{with_quantize_dynamic}_{set_layer_cache}"] = TestMultiHeadedAttention + + +with open(fname, "w") as fh: + fh.write(", torch, q_torch, turbo_transformers\n") + +for set_layer_cache in [True, False]: + for post_add_input in [False]: + for pre_layernorm in [False]: + for batch_size in [4]: + for query_seq_len in [1, 2]: + create_test(batch_size, + query_seq_len, + query_seq_len, + "self", + pre_layernorm, + post_add_input, + with_quantize_dynamic=False, + set_layer_cache=set_layer_cache) + +for set_layer_cache in [False, True]: + for post_add_input in [False]: + for pre_layernorm in [False]: + for batch_size in [4]: + for key_seq_len in [10, 20, 30, 40, 50]: + for query_seq_len in [1, 2]: + create_test(batch_size, + key_seq_len, + query_seq_len, + "context", + pre_layernorm, + post_add_input, + with_quantize_dynamic=False, + set_layer_cache=set_layer_cache) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/decoder_transformer_decoder_layer_test.py b/turbo_transformers/python/tests/decoder_transformer_decoder_layer_test.py new file mode 100644 index 00000000..6096ba44 --- /dev/null +++ b/turbo_transformers/python/tests/decoder_transformer_decoder_layer_test.py @@ -0,0 +1,185 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. 
You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. +import turbo_transformers + +import unittest +import sys +import torch +import os + +from onmt.decoders.transformer import TransformerDecoderLayer + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "tt_decoder_transformer_decoder_layer.txt" + + +def create_test(batch_size, src_length, T, with_quantize_dynamic=False): + class TestDecoder(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + torch.set_grad_enabled(False) + self.model_dim = 1024 + self.onmt_decoder = TransformerDecoderLayer(d_model=self.model_dim, + heads=8, + d_ff=1024, + dropout=0., + attention_dropout=0.) + self.onmt_decoder.eval() + if use_cuda: + self.onmt_decoder.to(self.test_device) + self.turbo_decoder = turbo_transformers.TransformerDecoderLayer.from_onmt( + self.onmt_decoder) + + # https://pytorch.org/docs/stable/quantization.html + if with_quantize_dynamic and not use_cuda: + self.quantized_onmt_decoder = torch.quantization.quantize_dynamic( + self.onmt_decoder) + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + deivce_type = "GPU" if use_cuda else "CPU" + info = f"\"({deivce_type}, {batch_size}, {src_length}, {T})\"" + + step = 2 + self.init_data(use_cuda=use_cuda) + + self.inputs = torch.rand(batch_size, + T, + self.model_dim, + dtype=torch.float32, + device=self.test_device) + self.memory_bank = torch.rand(batch_size, + src_length, + self.model_dim, + dtype=torch.float32, + device=self.test_device) + + self.src_pad_mask = torch.zeros(batch_size, + 1, + src_length, + dtype=torch.float32, + device=self.test_device).bool() + self.tgt_pad_mask = torch.zeros(batch_size, + 1, + T, + dtype=torch.float32, + device=self.test_device).bool() + + onmt_model = lambda: self.onmt_decoder(self.inputs, + self.memory_bank, + self.src_pad_mask, + self.tgt_pad_mask, + layer_cache=None, + step=step, + future=False) + + onmt_result, torch_qps, torch_time_consume = \ + test_helper.run_model(onmt_model, use_cuda, num_iter) + + onmt_mid, attns, attn_align = onmt_result + + print( + f"ONMT Deocder {info} ", + f"{deivce_type} QPS, {torch_qps}, time, {torch_time_consume}") + + if with_quantize_dynamic and not use_cuda: + quantized_onmt_model = lambda: self.quantized_onmt_decoder( + self.inputs, + self.memory_bank, + self.src_pad_mask, + self.tgt_pad_mask, + layer_cache=None, + step=step, + future=False) + + quantized_onmt_result, quantized_torch_qps, quantized_torch_time_consume = \ + test_helper.run_model(quantized_onmt_model, use_cuda, num_iter) + + quantized_onmt_mid, quantized_attns, quantized_attn_align = quantized_onmt_result + + print( + f"ONMT Quantized Deocder {info} ", + f"{deivce_type} QPS, {quantized_torch_qps}, time, {quantized_torch_time_consume}" + ) + + # print(onmt_mid) + # print(quantized_onmt_mid) + + # self.assertTrue( + # torch.max(torch.abs(onmt_mid - + # quantized_onmt_mid)) < (1e-3 if use_cuda else 1e-4)) + # self.assertTrue( + # 
torch.max(torch.abs(attns - quantized_attns)) < ( + # 1e-3 if use_cuda else 1e-4)) + + turbo_model = lambda: self.turbo_decoder(self.inputs, + self.memory_bank, + self.src_pad_mask, + self.tgt_pad_mask, + layer_cache=None, + step=step, + future=False) + + with turbo_transformers.pref_guard(info) as perf: + turbo_result, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_model, use_cuda, num_iter) + + turbo_mid, turbo_attns, _ = turbo_result + + print( + f"Turbo Deocder {info} ", + f"{deivce_type} QPS, {turbo_qps}, time, {turbo_time_consume}") + + self.assertTrue( + torch.max(torch.abs(onmt_mid - + turbo_mid)) < (1e-3 if use_cuda else 1e-4)) + self.assertTrue( + torch.max(torch.abs(attns - turbo_attns)) < ( + 1e-3 if use_cuda else 1e-4)) + + if with_quantize_dynamic and not use_cuda: + with open(fname, "a") as fh: + fh.write( + f"{info} {torch_qps}, {quantized_torch_qps}, {turbo_qps}\n" + ) + else: + with open(fname, "a") as fh: + fh.write(f"{info} {torch_qps}, {turbo_qps}\n") + + def test_decoder(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals( + )[f"TestDecoder{batch_size}_{src_length}_{T}_{with_quantize_dynamic}"] = TestDecoder + + +with open(fname, "w") as fh: + fh.write(", torch, *q_torch, turbo_transformers\n") + +for quantize in [True]: + for batch_size in [4]: + for src_length in [10, 20, 40, 60, 80, 100]: + for T in range(10, src_length, 10): + create_test(batch_size, src_length, T, quantize) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/positionwise_ffn_test.py b/turbo_transformers/python/tests/positionwise_ffn_test.py new file mode 100644 index 00000000..d88fb2d1 --- /dev/null +++ b/turbo_transformers/python/tests/positionwise_ffn_test.py @@ -0,0 +1,117 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. 
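The decoder attention and decoder layer tests above drive the incremental-decoding path through a layer_cache dict that is shared between calls. A minimal sketch of the cache layout those tests rely on (key names and shapes taken from the tests; the tensors and the commented call are placeholders):

```python
import turbo_transformers

# Each entry is either None or a (batch, head_count, seq_len, size_per_head) tensor.
layer_cache = {
    "memory_keys": None,    # filled by the first "context" attention call, then reused
    "memory_values": None,
    "self_keys": None,      # grown along the sequence dim on every "self" attention call
    "self_values": None,
}

# Hypothetical usage, following decoder_multi_headed_attention_test.py:
#   turbo_attn = turbo_transformers.MultiHeadedAttention.from_onmt(onmt_attn, layernorm)
#   output, attns = turbo_attn(K, V, Q, mask, layer_cache=layer_cache,
#                              attn_type="self", pre_layernorm=False,
#                              post_add_input=False, is_trans_weight=True)
```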
+ +import turbo_transformers + +import unittest +import sys +import torch +import os + +from onmt.modules.position_ffn import PositionwiseFeedForward + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "ffn.txt" + + +def create_test(batch_size, input_len): + class TestPositionwiseFeedForward(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + self.model_dim = 1024 + self.d_ff = 4096 + + torch.set_grad_enabled(False) + onmt_ffn = PositionwiseFeedForward(self.model_dim, self.d_ff) + onmt_ffn.eval() + if use_cuda: + onmt_ffn.to(self.test_device) + + turbo_ffn_trans = turbo_transformers.PositionwiseFeedForward.from_onmt( + onmt_ffn, is_trans_weight=True) + turbo_ffn_notrans = turbo_transformers.PositionwiseFeedForward.from_onmt( + onmt_ffn, is_trans_weight=False) + # (batch_size, input_len, model_dim) + inputs = torch.rand(size=(batch_size, input_len, self.model_dim), + dtype=torch.float32, + device=self.test_device) + return onmt_ffn, turbo_ffn_trans, turbo_ffn_notrans, inputs + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + onmt_ffn, turbo_ffn_trans, turbo_ffn_notrans, inputs = self.init_data( + use_cuda) + device = "GPU" if use_cuda else "CPU" + onmt_model = lambda: onmt_ffn(inputs) + onmt_model_result, torch_qps, torch_time_consume = \ + test_helper.run_model(onmt_model, use_cuda, num_iter) + + print( + f"PositionwiseFeedForward \"({batch_size}, {input_len:03})\" ", + f"{device} ONMT QPS, {torch_qps}, time, {torch_time_consume}") + + turbo_model_trans = lambda: turbo_ffn_trans(inputs, + is_trans_weight=True) + with turbo_transformers.pref_guard("gpref_test") as perf: + turbo_model_result, turbo_qps_trans, turbo_time_consume_trans = \ + test_helper.run_model(turbo_model_trans, use_cuda, num_iter) + + print( + f"PositionwiseFeedForward \"({batch_size}, {input_len:03})\" ", + f"{device} Turbo Trans QPS, {turbo_qps_trans}, time, {turbo_time_consume_trans}" + ) + + turbo_model_notrans = lambda: turbo_ffn_notrans( + inputs, is_trans_weight=False) + with turbo_transformers.pref_guard("gpref_test") as perf: + turbo_model_result, turbo_qps_notrans, turbo_time_consume_notrans = \ + test_helper.run_model(turbo_model_notrans, use_cuda, num_iter) + + print( + f"PositionwiseFeedForward Notrans \"({batch_size}, {input_len:03})\" ", + f"{device} Turbo NoTrans QPS, {turbo_qps_notrans}, time, {turbo_time_consume_notrans}" + ) + self.assertTrue( + torch.max(torch.abs(turbo_model_result - onmt_model_result)) < + (1e-3 if use_cuda else 1e-4)) + + with open(fname, "a") as fh: + fh.write( + f"\"({batch_size},{input_len:03})\", {torch_qps}, {turbo_qps_trans}, {turbo_qps_notrans}\n" + ) + + def test_positionwise_feed_forward(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals( + )[f"TestPositionwiseFeedForward{batch_size}_{input_len:3}"] = TestPositionwiseFeedForward + + +with open(fname, "w") as fh: + fh.write(", torch, turbo_trans, turbo_notrans\n") + +for batch_size in [4]: + for input_len in [10, 20, 30, 40, 50]: + create_test(batch_size, input_len) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/test_helper.py b/turbo_transformers/python/tests/test_helper.py index 2d7f8bb2..5075025f 100644 --- 
a/turbo_transformers/python/tests/test_helper.py +++ b/turbo_transformers/python/tests/test_helper.py @@ -15,8 +15,12 @@ import torch.jit import torch.onnx +import cProfile +import cProfile, pstats, io +from pstats import SortKey -def run_model(model, use_cuda, num_iter=50): + +def run_model(model, use_cuda, num_iter=50, use_profile=False): # warm up model() if use_cuda: @@ -38,3 +42,15 @@ def run_model(model, use_cuda, num_iter=50): qps = num_iter / t.elapsed time_consume = t.elapsed / num_iter return result, qps, time_consume + + +# for debug +def show_tensor(T, info): + if T is None: + print(info, " None") + return + T = torch.clone(T) + print(info, T.size()) + print(T.flatten()[0:10]) + print(T.flatten()[-10:]) + print(torch.sum(T.flatten())) diff --git a/turbo_transformers/python/turbo_transformers/layers/__init__.py b/turbo_transformers/python/turbo_transformers/layers/__init__.py index 8f93dd31..aae6e97e 100644 --- a/turbo_transformers/python/turbo_transformers/layers/__init__.py +++ b/turbo_transformers/python/turbo_transformers/layers/__init__.py @@ -13,19 +13,13 @@ from .modeling_bert import BertEmbeddings, BertIntermediate, BertOutput, BertAttention, BertLayer, SequencePool, \ BertEncoder, BertModel, PoolingType, BertPooler, BertModelWithPooler +from .modeling_decoder import MultiHeadedAttention, PositionwiseFeedForward, TransformerDecoderLayer, TransformerDecoder from .return_type import ReturnType __all__ = [ - 'BertEmbeddings', - 'BertIntermediate', - 'BertOutput', - 'BertAttention', - 'BertLayer', - 'BertEncoder', - 'BertModel', - 'ReturnType', - 'BertPooler', - 'SequencePool', - 'PoolingType', - 'BertModelWithPooler', + 'BertEmbeddings', 'BertIntermediate', 'BertOutput', 'BertAttention', + 'BertLayer', 'BertEncoder', 'BertModel', 'ReturnType', 'BertPooler', + 'SequencePool', 'PoolingType', 'BertModelWithPooler', + 'MultiHeadedAttention', 'PositionwiseFeedForward', + 'PositionwiseFeedForward', 'TransformerDecoderLayer', 'TransformerDecoder' ] diff --git a/turbo_transformers/python/turbo_transformers/layers/modeling_bert.py b/turbo_transformers/python/turbo_transformers/layers/modeling_bert.py index 22f13d68..d90cece6 100644 --- a/turbo_transformers/python/turbo_transformers/layers/modeling_bert.py +++ b/turbo_transformers/python/turbo_transformers/layers/modeling_bert.py @@ -19,7 +19,7 @@ from typing import Union, Optional, Sequence import torch from .return_type import convert_returns_as_type, ReturnType -import torch.utils.dlpack as dlpack +from .utils import try_convert, convert2tt_tensor, to_param_dict_convert_tt, to_param_dict, create_empty_if_none, AnyTensor from transformers.modeling_bert import BertEmbeddings as TorchBertEmbeddings from transformers.modeling_bert import BertIntermediate as TorchBertIntermediate @@ -40,37 +40,6 @@ ] -def _try_convert(t): - if isinstance(t, torch.Tensor): - return convert2tt_tensor(t) - elif isinstance(t, np.ndarray): - return convert2tt_tensor(torch.from_numpy(t)) - else: - return t - - -def convert2tt_tensor(t): - return cxx.Tensor.from_dlpack(dlpack.to_dlpack(t)) - - -def _to_param_dict(torch_module: torch.nn.Module): - return { - k: convert2tt_tensor(v) - for k, v in torch_module.named_parameters() - } - - -def _to_param_dict_naive(torch_module: torch.nn.Module): - return {k: v for k, v in torch_module.named_parameters()} - - -def _create_empty_if_none(output): - return output if output is not None else cxx.Tensor.create_empty() - - -AnyTensor = Union[cxx.Tensor, torch.Tensor] - - class BertEmbeddings(cxx.BERTEmbedding): def 
__call__(self, input_ids: AnyTensor, @@ -78,17 +47,17 @@ def __call__(self, token_type_ids: AnyTensor, return_type: Optional[ReturnType] = None, output: Optional[cxx.Tensor] = None): - input_ids = _try_convert(input_ids) - position_ids = _try_convert(position_ids) - token_type_ids = _try_convert(token_type_ids) - output = _create_empty_if_none(output) + input_ids = try_convert(input_ids) + position_ids = try_convert(position_ids) + token_type_ids = try_convert(token_type_ids) + output = create_empty_if_none(output) super(BertEmbeddings, self).__call__(input_ids, position_ids, token_type_ids, output) return convert_returns_as_type(output, return_type) @staticmethod def from_torch(bert_embedding: TorchBertEmbeddings) -> 'BertEmbeddings': - params = _to_param_dict(bert_embedding) + params = to_param_dict_convert_tt(bert_embedding) return BertEmbeddings(params['word_embeddings.weight'], params['position_embeddings.weight'], params['token_type_embeddings.weight'], @@ -99,11 +68,11 @@ def from_torch(bert_embedding: TorchBertEmbeddings) -> 'BertEmbeddings': def from_npz(file_name: str): f = np.load(file_name) return BertEmbeddings( - _try_convert(f['embeddings.word_embeddings.weight']), - _try_convert(f['embeddings.position_embeddings.weight']), - _try_convert(f['embeddings.token_type_embeddings.weight']), - _try_convert(f['embeddings.LayerNorm.weight']), - _try_convert(f['embeddings.LayerNorm.bias'])) + try_convert(f['embeddings.word_embeddings.weight']), + try_convert(f['embeddings.position_embeddings.weight']), + try_convert(f['embeddings.token_type_embeddings.weight']), + try_convert(f['embeddings.LayerNorm.weight']), + try_convert(f['embeddings.LayerNorm.bias'])) class BertIntermediate(cxx.BertIntermediate): @@ -111,15 +80,16 @@ def __call__(self, input_tensor: AnyTensor, return_type: Optional[ReturnType] = None, output: Optional[cxx.Tensor] = None): - input_tensor = _try_convert(input_tensor) - output = _create_empty_if_none(output) + input_tensor = try_convert(input_tensor) + output = create_empty_if_none(output) super(BertIntermediate, self).__call__(input_tensor, output) return convert_returns_as_type(output, return_type) @staticmethod def from_torch(intermediate: TorchBertIntermediate): - intermediate_params = _to_param_dict_naive(intermediate) - weight = torch.clone(torch.t(intermediate_params["dense.weight"])) + intermediate_params = to_param_dict(intermediate) + weight = torch.clone( + torch.t(intermediate_params["dense.weight"]).contiguous()) return BertIntermediate( convert2tt_tensor(weight), convert2tt_tensor(intermediate_params['dense.bias'])) @@ -128,9 +98,9 @@ def from_torch(intermediate: TorchBertIntermediate): def from_npz(file_name: str, layer_num: int): f = np.load(file_name) return BertIntermediate( - _try_convert( + try_convert( f[f'encoder.layer.{layer_num}.intermediate.dense.weight']), - _try_convert( + try_convert( f[f'encoder.layer.{layer_num}.intermediate.dense.bias'])) @@ -140,18 +110,18 @@ def __call__(self, attention_output: AnyTensor, return_type: Optional[ReturnType] = None, output: Optional[cxx.Tensor] = None): - intermediate_output = _try_convert(intermediate_output) - attention_output = _try_convert(attention_output) - output = _create_empty_if_none(output) + intermediate_output = try_convert(intermediate_output) + attention_output = try_convert(attention_output) + output = create_empty_if_none(output) super(BertOutput, self).__call__(intermediate_output, attention_output, output) return convert_returns_as_type(output, return_type) @staticmethod def 
from_torch(output: TorchBertOutput): - params = _to_param_dict_naive(output) - weight = convert2tt_tensor(torch.clone(torch.t( - params["dense.weight"]))) + params = to_param_dict(output) + weight = convert2tt_tensor( + torch.clone(torch.t(params["dense.weight"]).contiguous())) return BertOutput(weight, convert2tt_tensor(params["dense.bias"]), convert2tt_tensor(params["LayerNorm.weight"]), convert2tt_tensor(params["LayerNorm.bias"])) @@ -160,12 +130,11 @@ def from_torch(output: TorchBertOutput): def from_npz(file_name: str, layer_num: int): f = np.load(file_name) return BertOutput( - _try_convert(f[f'encoder.layer.{layer_num}.output.dense.weight']), - _try_convert(f[f'encoder.layer.{layer_num}.output.dense.bias']), - _try_convert( + try_convert(f[f'encoder.layer.{layer_num}.output.dense.weight']), + try_convert(f[f'encoder.layer.{layer_num}.output.dense.bias']), + try_convert( f[f'encoder.layer.{layer_num}.output.LayerNorm.weight']), - _try_convert( - f[f'encoder.layer.{layer_num}.output.LayerNorm.bias'])) + try_convert(f[f'encoder.layer.{layer_num}.output.LayerNorm.bias'])) class BertAttention(cxx.BertAttention): @@ -173,30 +142,42 @@ def __call__(self, input_tensor: AnyTensor, attention_mask: AnyTensor, return_type: Optional[ReturnType] = None, - output: Optional[cxx.Tensor] = None): - input_tensor = _try_convert(input_tensor) - attention_mask = _try_convert(attention_mask) - output = _create_empty_if_none(output) - super(BertAttention, self).__call__(input_tensor, attention_mask, - output) - return convert_returns_as_type(output, return_type) + output: Optional[cxx.Tensor] = None, + is_trans_weight: Optional[cxx.Tensor] = False): + """ + implement BertSelfAttention in + https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py#L183 + self.output_attentions always true + return (context_layer, attention_probs) + """ + input_tensor = try_convert(input_tensor) + attention_mask = try_convert(attention_mask) + output = create_empty_if_none(output) + attn_probs = cxx.Tensor.create_empty() + super(BertAttention, + self).__call__(input_tensor, attention_mask, output, attn_probs, + is_trans_weight) + return convert_returns_as_type(output, + return_type), convert_returns_as_type( + attn_probs, return_type) @staticmethod def from_torch(attention: TorchBertAttention): params = {k: v for k, v in attention.named_parameters()} - with torch.no_grad(): # merge self.query.weight, self.query.weight and self.query.weight together as qkv.weight qkv_weight = torch.clone( torch.t( torch.cat((params['self.query.weight'], params['self.key.weight'], - params['self.value.weight']), 0))) + params['self.value.weight']), + 0).contiguous()).contiguous()) qkv_bias = torch.cat( (params['self.query.bias'], params['self.key.bias'], - params['self.value.bias']), 0) + params['self.value.bias']), 0).contiguous() - output_weight = torch.clone(torch.t(params['output.dense.weight'])) + output_weight = torch.clone( + torch.t(params['output.dense.weight']).contiguous()) att = BertAttention( convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias), convert2tt_tensor(output_weight), @@ -211,16 +192,16 @@ def from_torch(attention: TorchBertAttention): def from_npz(file_name: str, layer_num: int, num_attention_heads: int): f = np.load(file_name) return BertAttention( - _try_convert(f[f'encoder.layer.{layer_num}.attention.qkv.weight']), - _try_convert(f[f'encoder.layer.{layer_num}.attention.qkv.bias']), - _try_convert( + try_convert(f[f'encoder.layer.{layer_num}.attention.qkv.weight']), + 
try_convert(f[f'encoder.layer.{layer_num}.attention.qkv.bias']), + try_convert( f[f'encoder.layer.{layer_num}.attention.output.dense.weight']), - _try_convert( + try_convert( f[f'encoder.layer.{layer_num}.attention.output.dense.bias']), - _try_convert(f[ + try_convert(f[ f'encoder.layer.{layer_num}.attention.output.LayerNorm.weight'] - ), - _try_convert( + ), + try_convert( f[f'encoder.layer.{layer_num}.attention.output.LayerNorm.bias'] ), num_attention_heads) @@ -239,7 +220,7 @@ def __call__(self, attention_output: Optional[cxx.Tensor] = None, intermediate_output: Optional[cxx.Tensor] = None, output: Optional[cxx.Tensor] = None): - attention_output = self.attention( + attention_output, attn = self.attention( hidden_states, attention_mask, return_type=ReturnType.turbo_transformers, @@ -251,7 +232,8 @@ def __call__(self, return self.output(intermediate_output, attention_output, return_type=return_type, - output=output) + output=output), convert_returns_as_type( + attn, return_type) @staticmethod def from_torch(layer: TorchBertLayer): @@ -279,9 +261,9 @@ def __call__(self, attention_output: Optional[cxx.Tensor] = None, intermediate_output: Optional[cxx.Tensor] = None, output: Optional[cxx.Tensor] = None): - attention_output = _create_empty_if_none(attention_output) - intermediate_output = _create_empty_if_none(intermediate_output) - output = _create_empty_if_none(output) + attention_output = create_empty_if_none(attention_output) + intermediate_output = create_empty_if_none(intermediate_output) + output = create_empty_if_none(output) first = True for l in self.layer: if first: @@ -290,12 +272,12 @@ def __call__(self, else: input_states = output - output = l(hidden_states=input_states, - attention_mask=attention_mask, - return_type=ReturnType.turbo_transformers, - attention_output=attention_output, - intermediate_output=intermediate_output, - output=output) + output, _ = l(hidden_states=input_states, + attention_mask=attention_mask, + return_type=ReturnType.turbo_transformers, + attention_output=attention_output, + intermediate_output=intermediate_output, + output=output) return convert_returns_as_type(output, return_type) @staticmethod @@ -319,8 +301,8 @@ def __call__(self, input_tensor: AnyTensor, return_type: Optional[ReturnType] = None, output_tensor: Optional[cxx.Tensor] = None): - input_tensor = _try_convert(input_tensor) - output_tensor = _create_empty_if_none(output_tensor) + input_tensor = try_convert(input_tensor) + output_tensor = create_empty_if_none(output_tensor) super(SequencePool, self).__call__(input_tensor, output_tensor) return convert_returns_as_type(output_tensor, return_type) @@ -355,13 +337,13 @@ def __call__(self, hidden_cache: Optional[AnyTensor] = None, output: Optional[AnyTensor] = None, return_type: Optional[ReturnType] = None): - attention_masks = _try_convert(_create_empty_if_none(attention_masks)) - token_type_ids = _try_convert(_create_empty_if_none(token_type_ids)) - position_ids = _try_convert(_create_empty_if_none(position_ids)) - inputs = _try_convert(inputs) + attention_masks = try_convert(create_empty_if_none(attention_masks)) + token_type_ids = try_convert(create_empty_if_none(token_type_ids)) + position_ids = try_convert(create_empty_if_none(position_ids)) + inputs = try_convert(inputs) extended_attention_masks = cxx.Tensor.create_empty() - output = _create_empty_if_none(output) - hidden_cache = _create_empty_if_none(hidden_cache) + output = create_empty_if_none(output) + hidden_cache = create_empty_if_none(hidden_cache) self.prepare(inputs, 
attention_masks, token_type_ids, position_ids, extended_attention_masks) @@ -417,23 +399,24 @@ def __call__(self, input_tensor: AnyTensor, return_type: Optional[ReturnType] = None, output: Optional[cxx.Tensor] = None): - input_tensor = _try_convert(input_tensor) - output = _create_empty_if_none(output) + input_tensor = try_convert(input_tensor) + output = create_empty_if_none(output) super(BertPooler, self).__call__(input_tensor, output) return convert_returns_as_type(output, return_type) @staticmethod def from_torch(pooler: TorchBertPooler): - pooler_params = _to_param_dict_naive(pooler) - weight = torch.clone(torch.t(pooler_params['dense.weight'])) + pooler_params = to_param_dict(pooler) + weight = torch.clone( + torch.t(pooler_params['dense.weight']).contiguous()) return BertPooler(convert2tt_tensor(weight), convert2tt_tensor(pooler_params['dense.bias'])) @staticmethod def from_npz(file_name: str, device: Optional[torch.device] = None): f = np.load(file_name) - return BertPooler(_try_convert(f['pooler.dense.weight']), - _try_convert(f['pooler.dense.bias'])) + return BertPooler(try_convert(f['pooler.dense.weight']), + try_convert(f['pooler.dense.bias'])) class BertModelWithPooler: diff --git a/turbo_transformers/python/turbo_transformers/layers/modeling_decoder.py b/turbo_transformers/python/turbo_transformers/layers/modeling_decoder.py new file mode 100644 index 00000000..d66e3029 --- /dev/null +++ b/turbo_transformers/python/turbo_transformers/layers/modeling_decoder.py @@ -0,0 +1,633 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. 
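
The modeling_bert.py hunks above do two things: the private `_try_convert` / `_create_empty_if_none` helpers move into the shared `layers/utils.py` module added later in this patch, and `BertAttention` (and therefore `BertLayer`) now hands back the attention probabilities alongside the context tensor, so existing callers must unpack a pair. A minimal sketch of the new calling convention follows; the config values, tensor shapes, and the additive-mask format are illustrative assumptions, not something this patch defines.

```python
import torch
from transformers import BertConfig
from transformers.modeling_bert import BertAttention as TorchBertAttention
from turbo_transformers.layers import BertAttention

cfg = BertConfig()  # defaults: hidden_size=768, num_attention_heads=12
torch_attn = TorchBertAttention(cfg).eval()
turbo_attn = BertAttention.from_torch(torch_attn)

hidden_states = torch.rand(2, 10, cfg.hidden_size)
# extended additive mask (assumed huggingface-style): 0.0 keeps a position,
# a large negative value masks it out
attention_mask = torch.zeros(2, 1, 1, 10)

# before this patch the call returned a single tensor;
# it now returns (context_layer, attention_probs)
context, attn_probs = turbo_attn(hidden_states, attention_mask)
```
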
+ +try: + # `turbo_transformers_cxxd` is the name on debug mode + import turbo_transformers.turbo_transformers_cxxd as cxx +except ImportError: + import turbo_transformers.turbo_transformers_cxx as cxx +from typing import Union, Optional, Sequence +import torch +from .return_type import convert_returns_as_type, ReturnType + +from .utils import try_convert, convert2tt_tensor, create_empty_if_none, AnyTensor + +from onmt.modules.multi_headed_attn import MultiHeadedAttention as OnmtMultiHeadedAttention +from transformers.modeling_bert import BertAttention as TorchBertAttention + +from onmt.modules.position_ffn import PositionwiseFeedForward as OnmtPositionwiseFeedForward +from onmt.decoders.transformer import TransformerDecoderLayer as OnmtTransformerDecoderLayer +from onmt.decoders.transformer import TransformerDecoder as OnmtTransformerDecoder +from onmt.modules import Embeddings as TorchBertEmbeddings + +from torch.nn import LayerNorm as TorchLayerNorm +from onmt.utils.misc import sequence_mask + +import enum +import numpy as np + +__all__ = [ + 'MultiHeadedAttention', 'PositionwiseFeedForward', + 'TransformerDecoderLayer', 'TransformerDecoder' +] + + +class MultiHeadedAttention(cxx.MultiHeadedAttention): + def __call__(self, + key_tensor: AnyTensor, + value_tensor: AnyTensor, + query_tensor: AnyTensor, + mask: Optional[AnyTensor] = None, + layer_cache: Optional[dict] = None, + attn_type: str = None, + pre_layernorm: bool = False, + post_layernorm: bool = False, + post_add_input: bool = False, + is_trans_weight: bool = False, + return_type: Optional[ReturnType] = None, + output: Optional[cxx.Tensor] = None, + attn: Optional[cxx.Tensor] = None): + """ Implement a MultiHeadedAttention of OpenNMT-py + https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/multi_headed_attn.py + + Attention: Now layer_cache only contains Nones + For self-dot Attention elements in dict `layer_cache` are Nones. + https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/decoders/transformer.py#L339 + """ + key_tensor = try_convert(key_tensor) + value_tensor = try_convert(value_tensor) + query_tensor = try_convert(query_tensor) + + mask = try_convert(create_empty_if_none(mask)) + + output = create_empty_if_none(output) + attn = create_empty_if_none(attn) + layer_cache_tmp = {} + if layer_cache is not None: + for k, v in layer_cache.items(): + if v is not None: + layer_cache_tmp[k] = try_convert(v) + else: + layer_cache_tmp[k] = create_empty_if_none(v) + + super(MultiHeadedAttention, + self).__call__(key_tensor, value_tensor, query_tensor, mask, + attn_type, output, attn, layer_cache_tmp, + pre_layernorm, post_layernorm, post_add_input, + is_trans_weight) + + if layer_cache is not None: + for k, v in layer_cache_tmp.items(): + if "memory" in k and "context" in attn_type or "self" in k and "self" in attn_type: + layer_cache[k] = convert_returns_as_type( + v, ReturnType.TORCH) + + return convert_returns_as_type(output, + return_type), convert_returns_as_type( + attn, return_type) + + @staticmethod + def pack_parameter(multi_headed_attn: OnmtMultiHeadedAttention, + is_trans_weight: Optional[bool] = False): + # linear_keys.weight + # linear_keys.bias + # linear_values.weight + # linear_values.bias + # linear_query.weight + # linear_query.bias + # final_linear.weight + # final_linear.bias + attn_params = {k: v for k, v in multi_headed_attn.named_parameters()} + if multi_headed_attn.max_relative_positions != 0: + raise "multi_headed_attn's max_relative_positions should be 0!" 
+ + # merge self.query.weight, self.query.weight and self.query.weight together as qkv.weight + if is_trans_weight: + qkv_weight = torch.cat((attn_params['linear_query.weight'], + attn_params['linear_keys.weight'], + attn_params['linear_values.weight']), 0) + k_w = convert2tt_tensor(attn_params['linear_keys.weight']) + v_w = convert2tt_tensor(attn_params['linear_values.weight']) + q_w = convert2tt_tensor(attn_params['linear_query.weight']) + f_w = convert2tt_tensor(attn_params['final_linear.weight']) + else: + qkv_weight = torch.clone( + torch.t( + torch.cat((attn_params['linear_query.weight'], + attn_params['linear_keys.weight'], + attn_params['linear_values.weight']), + 0).contiguous()).contiguous()) + k_w = convert2tt_tensor( + torch.clone( + torch.t(attn_params['linear_keys.weight']).contiguous())) + v_w = convert2tt_tensor( + torch.clone( + torch.t(attn_params['linear_values.weight']).contiguous())) + q_w = convert2tt_tensor( + torch.clone( + torch.t(attn_params['linear_query.weight']).contiguous())) + f_w = convert2tt_tensor( + torch.clone( + torch.t(attn_params['final_linear.weight']).contiguous())) + + qkv_bias = torch.cat( + (attn_params['linear_query.bias'], attn_params['linear_keys.bias'], + attn_params['linear_values.bias']), 0) + return (k_w, convert2tt_tensor(attn_params['linear_keys.bias']), v_w, + convert2tt_tensor(attn_params['linear_values.bias']), q_w, + convert2tt_tensor(attn_params['linear_query.bias']), f_w, + convert2tt_tensor(attn_params['final_linear.bias']), + convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias)) + + @staticmethod + def from_onmt(multi_headed_attn: OnmtMultiHeadedAttention, + is_trans_weight: bool = False): + attn_params = {k: v for k, v in multi_headed_attn.named_parameters()} + if multi_headed_attn.max_relative_positions != 0: + raise "multi_headed_attn's max_relative_positions should be 0!" + + with torch.no_grad(): + att = MultiHeadedAttention( + *(MultiHeadedAttention.pack_parameter(attn_params, + is_trans_weight)), + multi_headed_attn.head_count) + return att + + @staticmethod + def from_onmt(multi_headed_attn: OnmtMultiHeadedAttention, + layer_norm: TorchLayerNorm, + is_trans_weight: bool = False): + ln_params = {k: v for k, v in layer_norm.named_parameters()} + attn_params = {k: v for k, v in multi_headed_attn.named_parameters()} + with torch.no_grad(): + att = MultiHeadedAttention( + *(MultiHeadedAttention.pack_parameter(multi_headed_attn, + is_trans_weight)), + convert2tt_tensor(ln_params['weight']), + convert2tt_tensor(ln_params['bias']), + multi_headed_attn.head_count) + return att + + @staticmethod + def from_torch(attention: TorchBertAttention, + layer_norm: Optional[TorchLayerNorm] = None, + is_trans_weight: bool = False): + """ + load an attn model from huggingface bert attention model. 
+ """ + ln_params = {} + if layer_norm is not None: + ln_params = {k: v for k, v in layer_norm.named_parameters()} + params = {k: v for k, v in attention.named_parameters()} + with torch.no_grad(): + if is_trans_weight: + # merge self.query.weight, self.query.weight and self.query.weight together as qkv.weight + qkv_weight = torch.cat( + (params['self.query.weight'], params['self.key.weight'], + params['self.value.weight']), 0) + output_weight = params['output.dense.weight'] + k_w = params['self.key.weight'] + v_w = params['self.value.weight'] + q_w = params['self.query.weight'] + else: + # merge self.query.weight, self.query.weight and self.query.weight together as qkv.weight + qkv_weight = torch.clone( + torch.t( + torch.cat((params['self.query.weight'], + params['self.key.weight'], + params['self.value.weight']), + 0).contiguous()).contiguous()) + output_weight = torch.clone( + torch.t(params['output.dense.weight']).contiguous()) + k_w = torch.clone( + torch.t(params['self.key.weight']).contiguous()) + v_w = torch.clone( + torch.t(params['self.value.weight']).contiguous()) + q_w = torch.clone( + torch.t(params['self.query.weight']).contiguous()) + + qkv_bias = torch.cat( + (params['self.query.bias'], params['self.key.bias'], + params['self.value.bias']), 0) + + if layer_norm is not None: + att = MultiHeadedAttention( + convert2tt_tensor(k_w), + convert2tt_tensor(params['self.key.bias']), + convert2tt_tensor(v_w), + convert2tt_tensor(params['self.value.bias']), + convert2tt_tensor(q_w), + convert2tt_tensor(params['self.query.bias']), + convert2tt_tensor(output_weight), + convert2tt_tensor(params['output.dense.bias']), + convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias), + convert2tt_tensor(params['output.LayerNorm.weight']), + convert2tt_tensor(params['output.LayerNorm.bias']), + convert2tt_tensor(ln_params['weight']), + convert2tt_tensor(ln_params['bias']), + attention.self.num_attention_heads) + else: + att = MultiHeadedAttention( + convert2tt_tensor(k_w), + convert2tt_tensor(params['self.key.bias']), + convert2tt_tensor(v_w), + convert2tt_tensor(params['self.value.bias']), + convert2tt_tensor(q_w), + convert2tt_tensor(params['self.query.bias']), + convert2tt_tensor(output_weight), + convert2tt_tensor(params['output.dense.bias']), + convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias), + convert2tt_tensor(params['output.LayerNorm.weight']), + convert2tt_tensor(params['output.LayerNorm.bias']), + attention.self.num_attention_heads) + return att + + @staticmethod + def from_npz(file_name: str, layer_num: int, num_attention_heads: int): + f = np.load(file_name) + return BertAttention( + create_empty_if_none(None), create_empty_if_none(None), + create_empty_if_none(None), create_empty_if_none(None), + create_empty_if_none(None), create_empty_if_none(None), + try_convert( + f[f'encoder.layer.{layer_num}.attention.output.dense.weight']), + try_convert( + f[f'encoder.layer.{layer_num}.attention.output.dense.bias']), + try_convert(f[f'encoder.layer.{layer_num}.attention.qkv.weight']), + try_convert(f[f'encoder.layer.{layer_num}.attention.qkv.bias']), + try_convert(f[ + f'encoder.layer.{layer_num}.attention.output.LayerNorm.weight'] + ), + try_convert( + f[f'encoder.layer.{layer_num}.attention.output.LayerNorm.bias'] + ), num_attention_heads) + + +class PositionwiseFeedForward(cxx.PositionwiseFeedForward): + def __call__( + self, + input_tensor: AnyTensor, + return_type: Optional[ReturnType] = None, + is_trans_weight: Optional[bool] = True, #Intel 61xx True is faster + output: 
Optional[cxx.Tensor] = None): + input_tensor = try_convert(input_tensor) + output = create_empty_if_none(output) + super(PositionwiseFeedForward, self).__call__(input_tensor, output, + is_trans_weight) + return convert_returns_as_type(output, return_type) + + @staticmethod + def from_onmt(position_wise_ffn: OnmtPositionwiseFeedForward, + is_trans_weight: Optional[bool] = True): + params = {k: v for k, v in position_wise_ffn.named_parameters()} + # w_1.weight + # w_1.bias + # w_2.weight + # w_2.bias + # layer_norm.weight + # layer_norm.bias + + # Note that torch's weights of linear layer is transposed + if is_trans_weight: + w_1 = convert2tt_tensor(params['w_1.weight']) + w_2 = convert2tt_tensor(params['w_2.weight']) + else: + w_1 = convert2tt_tensor( + torch.clone(torch.t(params['w_1.weight']).contiguous())) + w_2 = convert2tt_tensor( + torch.clone(torch.t(params['w_2.weight']).contiguous())) + + with torch.no_grad(): + ffn = PositionwiseFeedForward( + w_1, convert2tt_tensor(params['w_1.bias']), w_2, + convert2tt_tensor(params['w_2.bias']), + convert2tt_tensor(params['layer_norm.weight']), + convert2tt_tensor(params['layer_norm.bias'])) + return ffn + + +class TransformerDecoderLayer: + def __init__(self, self_attn: MultiHeadedAttention, + context_attn: MultiHeadedAttention, + feed_forward: PositionwiseFeedForward): + """ Implement class TransformerDecoderLayer(nn.Module): + https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/decoders/transformer.py + self_attn_type of MultiHeadedAttention should always scaled-dot + """ + self.self_attn = self_attn + if not isinstance(self_attn, MultiHeadedAttention): + raise "self_attn should be of type MultiHeadedAttention" + self.context_attn = context_attn + if not isinstance(context_attn, MultiHeadedAttention): + raise "context_attn should be of type MultiHeadedAttention" + self.feed_forward = feed_forward + + def __call__(self, + input_tensor: torch.Tensor, + memory_bank: torch.Tensor, + src_pad_mask: torch.Tensor, + tgt_pad_mask: torch.Tensor, + layer_cache: Optional[dict] = None, + step: Optional[int] = None, + future: Optional[bool] = False, + with_align: Optional[bool] = False, + return_type: Optional[ReturnType] = None, + output: Optional[cxx.Tensor] = None): + """ Implement _forward method of class TransformerDecoderLayer + https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/decoders/transformer.py + Because we now do not need context aligment, so we do not provide a forward method + Args: + input_tensor (FloatTensor): ``(batch_size, T, model_dim)`` + memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)`` + src_pad_mask (bool): ``(batch_size, 1, src_len)`` + tgt_pad_mask (bool): ``(batch_size, 1, T)`` + layer_cache (dict or None): cached layer info when stepwise decode + step (int or None): stepwise decoding counter + future (bool): If set True, do not apply future_mask. 
+ Returns: + (FloatTensor, FloatTensor): + * output ``(batch_size, T, model_dim)`` + * top_attns ``(batch_size, T, src_len)`` or None + * attn_align None + """ + # dec_mask = None which is no mask + dec_mask = None + + input_tensor = try_convert(input_tensor) + memory_bank = try_convert(memory_bank) + src_pad_mask = src_pad_mask.float() * -1e18 + src_pad_mask = try_convert(src_pad_mask) + + if step is None: + tgt_len = tgt_pad_mask.size(-1) + if not future: # apply future_mask, result mask in (B, T, T) + future_mask = torch.ones([tgt_len, tgt_len], + device=tgt_pad_mask.device, + dtype=torch.float32) + future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len) + # BoolTensor was introduced in pytorch 1.2 + # try: + # future_mask = future_mask.bool() + # except AttributeError: + # pass + dec_mask = torch.gt(tgt_pad_mask + future_mask, 0).float() + else: # only mask padding, result mask in (B, 1, T) + dec_mask = tgt_pad_mask + + if dec_mask is None: + dec_mask = create_empty_if_none(dec_mask) + else: + dec_mask = dec_mask * -1e18 + dec_mask = try_convert(dec_mask) + + query, _ = self.self_attn(input_tensor, + input_tensor, + input_tensor, + mask=dec_mask, + layer_cache=layer_cache, + attn_type="self", + pre_layernorm=True, + post_add_input=True, + return_type=ReturnType.turbo_transformers) + + mid, attns = self.context_attn( + memory_bank, + memory_bank, + query, + mask=src_pad_mask, + layer_cache=layer_cache, + attn_type="context", + pre_layernorm=True, + post_add_input=True, + return_type=ReturnType.turbo_transformers) + + output = self.feed_forward(mid, return_type=return_type) + return output, convert_returns_as_type( + attns, return_type)[:, 0, :, :].contiguous( + ), None #attn_aligned mast be None + + @staticmethod + def from_onmt(transformer_decoder_layer: OnmtTransformerDecoderLayer): + params = { + k: v + for k, v in transformer_decoder_layer.named_parameters() + } + # for k, v in transformer_decoder_layer.named_parameters(): + # print(k, v.size()) + + # 12: self_attn.linear_keys.weight torch.Size([1024, 1024]) + # 12: self_attn.linear_keys.bias torch.Size([1024]) + # 12: self_attn.linear_values.weight torch.Size([1024, 1024]) + # 12: self_attn.linear_values.bias torch.Size([1024]) + # 12: self_attn.linear_query.weight torch.Size([1024, 1024]) + # 12: self_attn.linear_query.bias torch.Size([1024]) + # 12: self_attn.final_linear.weight torch.Size([1024, 1024]) + # 12: self_attn.final_linear.bias torch.Size([1024]) + # 12: context_attn.linear_keys.weight torch.Size([1024, 1024]) + # 12: context_attn.linear_keys.bias torch.Size([1024]) + # 12: context_attn.linear_values.weight torch.Size([1024, 1024]) + # 12: context_attn.linear_values.bias torch.Size([1024]) + # 12: context_attn.linear_query.weight torch.Size([1024, 1024]) + # 12: context_attn.linear_query.bias torch.Size([1024]) + # 12: context_attn.final_linear.weight torch.Size([1024, 1024]) + # 12: context_attn.final_linear.bias torch.Size([1024]) + # 12: feed_forward.w_1.weight torch.Size([1, 1024]) + # 12: feed_forward.w_1.bias torch.Size([1]) + # 12: feed_forward.w_2.weight torch.Size([1024, 1]) + # 12: feed_forward.w_2.bias torch.Size([1024]) + # 12: feed_forward.layer_norm.weight torch.Size([1024]) + # 12: feed_forward.layer_norm.bias torch.Size([1024]) + # 12: layer_norm_1.weight torch.Size([1024]) + # 12: layer_norm_1.bias torch.Size([1024]) + # 12: layer_norm_2.weight torch.Size([1024]) + # 12: layer_norm_2.bias torch.Size([1024]) + # 12: w_1.weight torch.Size([1, 1024]) + # 12: w_1.bias torch.Size([1]) + # 12: 
w_2.weight torch.Size([1024, 1]) + # 12: w_2.bias torch.Size([1024]) + # 12: layer_norm.weight torch.Size([1024]) + # 12: layer_norm.bias torch.Size([1024]) + + self_attn = MultiHeadedAttention.from_onmt( + transformer_decoder_layer.self_attn, + transformer_decoder_layer.layer_norm_1) + context_attn = MultiHeadedAttention.from_onmt( + transformer_decoder_layer.context_attn, + transformer_decoder_layer.layer_norm_2) + feed_forward = PositionwiseFeedForward.from_onmt( + transformer_decoder_layer.feed_forward) + + return TransformerDecoderLayer(self_attn, context_attn, feed_forward) + + +class TransformerDecoder: + """The Transformer decoder from "Attention is All You Need". + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + .. mermaid:: + graph BT + A[input] + B[multi-head self-attn] + BB[multi-head src-attn] + C[feed forward] + O[output] + A --> B + B --> BB + BB --> C + C --> O + Args: + num_layers (int): number of encoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + copy_attn (bool): if using a separate copy attention + self_attn_type (str): type of self-attention scaled-dot, average + dropout (float): dropout in residual, self-attn(dot) and feed-forward + attention_dropout (float): dropout in context_attn (and self-attn(avg)) + embeddings (onmt.modules.Embeddings): + embeddings to use, should have positional encodings + max_relative_positions (int): + Max distance between inputs in relative positions representations, TODO(jiaruifang) only support 0 + aan_useffn (bool): Turn on the FFN layer in the AAN decoder, TODO(jiaruifang) only support False + full_context_alignment (bool): + whether enable an extra full context decoder forward for alignment + alignment_layer (int): N° Layer to supervise with for alignment guiding + alignment_heads (int): + N. of cross attention heads to use for alignment guiding + """ + def __init__(self, + embeddings: TorchBertEmbeddings, + transformer_layers: Sequence[TransformerDecoderLayer], + layer_norm: TorchLayerNorm, + copy_attn: Optional[bool] = False, + alignment_layer: Optional[int] = 0): + self.embeddings = embeddings + + # Decoder State + self.state = {} + + self.transformer_layers = transformer_layers + + # previously, there was a GlobalAttention module here for copy + # attention. But it was never actually used -- the "copy" attention + # just reuses the context attention. 
+ self._copy = copy_attn #bool + self.layer_norm = layer_norm + + self.alignment_layer = alignment_layer + + def init_state(self, src, memory_bank, enc_hidden): + """Initialize decoder state.""" + self.state["src"] = src + self.state["cache"] = None + + def map_state(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + self.state["src"] = fn(self.state["src"], 1) + if self.state["cache"] is not None: + _recursive_map(self.state["cache"]) + + def detach_state(self): + self.state["src"] = self.state["src"].detach() + + def __call__(self, + tgt: torch.Tensor, + memory_bank: torch.Tensor, + step: Optional[int] = None, + **kwargs): + """Decode, possibly stepwise.""" + if step == 0: + self._init_cache(memory_bank) + + tgt_words = tgt[:, :, 0].transpose(0, 1) + + emb = self.embeddings(tgt, step=step) + assert emb.dim() == 3 # len x batch x embedding_dim + + output = emb.transpose(0, 1).contiguous() + src_memory_bank = memory_bank.transpose(0, 1).contiguous() + + pad_idx = self.embeddings.word_padding_idx + src_lens = kwargs["memory_lengths"] + src_max_len = self.state["src"].shape[0] + #Turbo add bool -> float + src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1) + tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] + + with_align = kwargs.pop('with_align', False) + if with_align: + raise "with_align must be False" + attn_aligns = [] + + # It's Turbo's show time! + for i, layer in enumerate(self.transformer_layers): + layer_cache = self.state["cache"]["layer_{}".format(i)] \ + if step is not None else None + output, attn, attn_align = layer(output, + src_memory_bank, + src_pad_mask, + tgt_pad_mask, + layer_cache=layer_cache, + step=step, + with_align=with_align) + if attn_align is not None: + attn_aligns.append(attn_align) + + # Turbo finished. 
+ output = self.layer_norm(output) + dec_outs = output.transpose(0, 1).contiguous() + attn = attn.transpose(0, 1).contiguous() + + attns = {"std": attn} + if self._copy: + attns["copy"] = attn + if with_align: + attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` + # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg + + # TODO(OpenNMT-py) change the way attns is returned dict => list or tuple (onnx) + + return dec_outs, attns + + def _init_cache(self, memory_bank): + self.state["cache"] = {} + batch_size = memory_bank.size(1) + depth = memory_bank.size(-1) + + for i, layer in enumerate(self.transformer_layers): + layer_cache = {"memory_keys": None, "memory_values": None} + if not isinstance(layer.self_attn, MultiHeadedAttention): + raise "MultiHeadedAttention only not supported" + else: + layer_cache["self_keys"] = None + layer_cache["self_values"] = None + self.state["cache"]["layer_{}".format(i)] = layer_cache + + @staticmethod + def from_onmt(model: OnmtTransformerDecoder, + device: Optional[torch.device] = None): + if device is not None and 'cuda' in device.type and torch.cuda.is_available( + ): + model.to(device) + layers = [ + TransformerDecoderLayer.from_onmt(transformer_layer) + for transformer_layer in model.transformer_layers + ] + return TransformerDecoder(model.embeddings, layers, model.layer_norm, + model._copy, model.alignment_layer) diff --git a/turbo_transformers/python/turbo_transformers/layers/return_type.py b/turbo_transformers/python/turbo_transformers/layers/return_type.py index 4f1fc861..ffe32803 100644 --- a/turbo_transformers/python/turbo_transformers/layers/return_type.py +++ b/turbo_transformers/python/turbo_transformers/layers/return_type.py @@ -12,6 +12,7 @@ # See the AUTHORS file for names of contributors. import enum +import torch import torch.utils.dlpack as dlpack try: # `turbo_transformers_cxxd` is the name on debug mode @@ -19,7 +20,6 @@ except ImportError: import turbo_transformers.turbo_transformers_cxx as cxx from typing import Optional, Union -import torch __all__ = ['ReturnType', 'convert_returns_as_type'] diff --git a/turbo_transformers/python/turbo_transformers/layers/utils.py b/turbo_transformers/python/turbo_transformers/layers/utils.py new file mode 100644 index 00000000..beae6a4d --- /dev/null +++ b/turbo_transformers/python/turbo_transformers/layers/utils.py @@ -0,0 +1,60 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. 
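
With the decoder classes in place, wiring them up follows the `from_onmt` factory shown above. The sketch below builds a tiny OpenNMT-py `TransformerDecoder`, converts it, and runs one non-stepwise forward pass. The hyper-parameters and the `(len, batch, ...)` tensor layout follow the docstring reproduced above; the OpenNMT-py constructor keyword names are assumptions about that library rather than anything defined in this patch.

```python
import torch
from onmt.modules import Embeddings
from onmt.decoders.transformer import TransformerDecoder as OnmtTransformerDecoder
from turbo_transformers.layers import TransformerDecoder

# a deliberately small decoder; every hyper-parameter here is illustrative
vocab, d_model, heads, d_ff, num_layers, pad_idx = 100, 64, 4, 256, 2, 1
embeddings = Embeddings(word_vec_size=d_model,
                        word_vocab_size=vocab,
                        word_padding_idx=pad_idx,
                        position_encoding=True)
onmt_decoder = OnmtTransformerDecoder(
    num_layers=num_layers, d_model=d_model, heads=heads, d_ff=d_ff,
    copy_attn=False, self_attn_type="scaled-dot",
    dropout=0.0, attention_dropout=0.0, embeddings=embeddings,
    max_relative_positions=0,       # the Turbo wrapper only supports 0
    aan_useffn=False, full_context_alignment=False,
    alignment_layer=0, alignment_heads=0)

turbo_decoder = TransformerDecoder.from_onmt(onmt_decoder)

src_len, tgt_len, batch = 7, 5, 2
src = torch.randint(2, vocab, (src_len, batch, 1))        # (src_len, B, 1) token ids
memory_bank = torch.rand(src_len, batch, d_model)         # encoder output
memory_lengths = torch.full((batch,), src_len, dtype=torch.long)
tgt = torch.randint(2, vocab, (tgt_len, batch, 1))

turbo_decoder.init_state(src, memory_bank, None)
dec_outs, attns = turbo_decoder(tgt, memory_bank, step=None,
                                memory_lengths=memory_lengths)
```

For stepwise beam-search style decoding, the same `__call__` is invoked once per step with `step=0, 1, ...` and a single-token `tgt`, which triggers `_init_cache` and the per-layer `layer_cache` path shown above.
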
+ +import torch +import torch.utils.dlpack as dlpack +from typing import Union +import numpy as np + +try: + # `turbo_transformers_cxxd` is the name on debug mode + import turbo_transformers.turbo_transformers_cxxd as cxx +except ImportError: + import turbo_transformers.turbo_transformers_cxx as cxx +from .return_type import convert_returns_as_type, ReturnType + +__all__ = [ + 'try_convert', 'convert2tt_tensor', 'to_param_dict_convert_tt', + 'to_param_dict', 'create_empty_if_none', 'AnyTensor' +] + + +def convert2tt_tensor(t): + return cxx.Tensor.from_dlpack(dlpack.to_dlpack(t)) + + +def try_convert(t): + if isinstance(t, torch.Tensor): + return convert2tt_tensor(t) + elif isinstance(t, np.ndarray): + return convert2tt_tensor(torch.from_numpy(t)) + else: + return t + + +def to_param_dict_convert_tt(torch_module: torch.nn.Module): + return { + k: convert2tt_tensor(v) + for k, v in torch_module.named_parameters() + } + + +def to_param_dict(torch_module: torch.nn.Module): + return {k: v for k, v in torch_module.named_parameters()} + + +def create_empty_if_none(output): + return output if output is not None else cxx.Tensor.create_empty() + + +AnyTensor = Union[cxx.Tensor, torch.Tensor] diff --git a/turbo_transformers/python/turbo_transformers/utils.py b/turbo_transformers/python/turbo_transformers/utils.py index 1c41ca0f..780257b3 100644 --- a/turbo_transformers/python/turbo_transformers/utils.py +++ b/turbo_transformers/python/turbo_transformers/utils.py @@ -18,13 +18,20 @@ import turbo_transformers.turbo_transformers_cxx as cxx import contextlib -__all__ = ['gperf_guard', 'set_num_threads'] +__all__ = [ + 'pref_guard', 'set_num_threads', 'set_stderr_verbose_level', + 'disable_perf', 'enable_perf' +] set_num_threads = cxx.set_num_threads +set_stderr_verbose_level = cxx.set_stderr_verbose_level + +disable_perf = cxx.disable_perf +enable_perf = cxx.enable_perf @contextlib.contextmanager -def gperf_guard(filename: str): - cxx.enable_gperf(filename) +def pref_guard(filename: str): + cxx.enable_perf(filename) yield - cxx.disable_gperf() + cxx.disable_perf()
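
Finally, the profiling hooks: `gperf_guard` / `enable_gperf` / `disable_gperf` are renamed to the gperftools-agnostic `pref_guard` / `enable_perf` / `disable_perf` (note the exported name really is spelled `pref_guard` in this patch), and `set_stderr_verbose_level` is newly exposed. A minimal sketch of the context-manager form; the dummy matmul merely stands in for whatever turbo_transformers forward pass you want to profile, and the output file name is arbitrary.

```python
import torch
from turbo_transformers.utils import pref_guard, set_num_threads

set_num_threads(4)

# the guard calls cxx.enable_perf("turbo_profile") on entry and
# cxx.disable_perf() on exit; in real use a turbo_transformers forward
# pass would sit here instead of the placeholder matmul
with pref_guard("turbo_profile"):
    _ = torch.rand(128, 128) @ torch.rand(128, 128)
```
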