diff --git a/.github/workflows/dockerimage.yml b/.github/workflows/dockerimage.yml
index 0ba5adb3..be1ce169 100644
--- a/.github/workflows/dockerimage.yml
+++ b/.github/workflows/dockerimage.yml
@@ -5,10 +5,13 @@ on:
branches:
- master
- develop
+ - feature/decoder
pull_request:
branches:
- master
- develop
+ - feature/decoder
+
jobs:
build:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e5f7877b..4b17e50e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,9 +21,9 @@ set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_FLAGS "-Wall")
set(CMAKE_C_FLAGS "-Wall")
-set(TURBO_TRANSFORMERS_VERSION 0.2.1)
+set(TURBO_TRANSFORMERS_VERSION 0.3.0)
-option(WITH_PROFILER "Compile with gperftools" OFF)
+option(WITH_PROFILER "Compile with profiler" OFF)
option(WITH_GPU "Build with GPU" OFF)
option(WITH_MODULE_BENCHMAKR "Catch2 unitest with benchmarking" ON)
@@ -65,9 +65,7 @@ endif ()
if (WITH_PROFILER)
- find_package(Gperftools REQUIRED)
- include_directories(${GPERFTOOLS_INCLUDE_DIR})
- add_definitions(-DWITH_GPERFTOOLS)
+ add_definitions(-DWITH_PERFTOOLS)
endif ()
IF (UNIX AND NOT APPLE)
diff --git a/Dockerfile_ci b/Dockerfile_ci
index 31b08585..b3304a32 100644
--- a/Dockerfile_ci
+++ b/Dockerfile_ci
@@ -1,14 +1,15 @@
-FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
+FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
RUN apt-get update && \
- apt-get install -y curl git wget bzip2 build-essential ninja-build g++ && rm -rf /var/lib/apt/lists/*
+ apt-get install -y curl git wget bzip2 build-essential ninja-build g++ gfortran && rm -rf /var/lib/apt/lists/*
ENV PATH=/opt/miniconda3/bin:${PATH} CONDA_PREFIX=/opt/miniconda3
RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3-latest-Linux-x86_64.sh -p /opt/miniconda3 -b && \
rm Miniconda3-latest-Linux-x86_64.sh && \
conda update -y conda && \
- conda install pytorch==1.4.0 cudatoolkit=10.0 && \
+ conda install pytorch==1.5.0 cudatoolkit=10.0 && \
+ pip install OpenNMT-py && \
conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \
conda install git git-lfs docopt -c conda-forge && \
conda clean -afy
diff --git a/README.md b/README.md
index cbad1020..f542566d 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,14 @@
## turbo_transformers: a fast and user-friendly tool for transformer inference on CPU and GPU
![logo](./images/logo.jpeg)
-### **make transformers serving fast by adding a turbo to your inference engine!**
+
**make transformers serving fast by adding a turbo to your inference engine!**
+### Background
Transformer is the most critical alogrithm innovation in the NLP field in recent years. It brings higher model accuracy while introduces more calculations. The efficient deployment of online Transformer-based services faces enormous challenges. In order to make the costly Transformer online service more efficient, the WeChat AI open-sourced a Transformer inference acceleration tool called TurboTransformers, which has the following characteristics.
-1. Excellent CPU / GPU performance. For Intel multi-core CPU and NVIDIA GPU hardware platforms, TurboTransformers can fully utilize all levels of computing power of the hardware. It has achieved better performance over pytorch / tensorflow and current mainstream optimization engines (such as onnxruntime-mkldnn / onnxruntime-gpu, torch JIT, NVIDIA faster transformers) on a variety of CPU and GPU hardware. See the detailed benchmark results below.
-2. Tailored to the characteristics of NLP inference tasks. Unlike the CV task, the input dimensions of the NLP inference task always change. The traditional approach is zero padding or truncation to a fixed length, which introduces additional zero padding computational overhead. Besides, some frameworks such as onnxruntime, tensorRT, and torchlib need to preprocess the compuatation-graph according to the input size in advance for the best performance, which is not suitable for NLP tasks with varying sizes. TurboTransformers can support variable-length input sequence processing without preprocessing.
-3. A simpler method of use. TurboTransformers supports python and C ++ interface for calling. It can be used as an acceleration plug-in for pytorch. In the Transformer task, the end-to-end acceleration effect obtained by adding a few lines of python code.
+1. Supports both the Transformer Encoder and Decoder.
+2. Excellent CPU / GPU performance. For Intel multi-core CPU and NVIDIA GPU hardware platforms, TurboTransformers can fully utilize all levels of computing power of the hardware. It has achieved better performance than pytorch / tensorflow and current mainstream optimization engines (such as onnxruntime-mkldnn / onnxruntime-gpu, torch JIT, NVIDIA faster transformers) on a variety of CPU and GPU hardware. See the detailed benchmark results below.
+3. Tailored to the characteristics of NLP inference tasks. Unlike CV tasks, the input dimensions of NLP inference tasks always change. The traditional approach is zero padding or truncation to a fixed length, which introduces additional zero-padding computational overhead. Besides, some frameworks such as onnxruntime, tensorRT, and torchlib need to preprocess the computation graph according to the input size in advance for the best performance, which is not suitable for NLP tasks with varying sizes. TurboTransformers supports variable-length input sequence processing without preprocessing.
+4. Simple to use. TurboTransformers provides python and C++ APIs. It can be used as an acceleration plug-in for pytorch. In a Transformer task, end-to-end acceleration can be obtained by adding just a few lines of python code.
TurboTransformers has been applied to multiple online BERT service scenarios in Tencent. For example, It brings 1.88x acceleration to the WeChat FAQ service, 2.11x acceleration to the public cloud sentiment analysis service, and 13.6x acceleration to the QQ recommendation system.
@@ -45,7 +47,7 @@ sh tools/build_and_run_unittests.sh $PWD -DWITH_GPU=OFF
# set(BLAS_PROVIDER "mkl" CACHE STRING "Set the blas provider library, in [openblas, mkl, blis]")
```
-Method 2:I do not want to unitest
+Method 2: I do not want to run unit tests
```
cd /workspace
mkdir -p build && cd build
@@ -67,7 +69,7 @@ sh tool/build_conda_package.sh
*We also prepared a docker image containing CPU version of TurboTransformers, as well as other related works, i.e. onnxrt v1.2.0 and pytorch-jit on dockerhub*
```
-docker pull thufeifeibear/turbo_transformers:0.2.0-release-cpu-dev
+docker pull thufeifeibear/turbo_transformers:0.3.0-cpu-dev
```
### Installation on GPU
```
@@ -77,8 +79,8 @@ git clone https://github.com/Tencent/TurboTransformers --recursive
```
# You can modify the environment variables in the script to specify the cuda version and operating system version
sh tools/build_docker_gpu.sh $PWD
-docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=your_container_name REPOSITORY:TAG
-# for example: docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=jiarui_gpu_env ccr.ccs.tencentyun.com/mmspr/turbo_transformers:0.1.1-cuda9.0-ubuntu16.04-gpu-dev
+nvidia-docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=your_container_name REPOSITORY:TAG
+# for example: nvidia-docker run --gpus all --net=host --rm -it -v $PWD:/workspace -v /etc/passwd:/etc/passwd --name=jiarui_gpu_env ccr.ccs.tencentyun.com/mmspr/turbo_transformers:0.1.1-cuda9.0-ubuntu16.04-gpu-dev
```
2. Install pip package in docker and unitest test
@@ -92,68 +94,42 @@ sh tools/build_and_run_unittests.sh $PWD -DWITH_GPU=ON
cd benchmark
bash gpu_run_benchmark.sh
```
-*We also prepared a docker image containing GPU version of TurboTransformers.
+We also prepared a docker image containing the GPU version of TurboTransformers.
```
-docker pull thufeifeibear/turbo_transformers:0.2.0-cuda10.0-cudnn7-devel-ubuntu18.04-gpu-release
+docker pull thufeifeibear/turbo_transformers:0.3.0-cuda10.0-cudnn7-devel-ubuntu18.04-gpu-dev
```
### Usage
-turbo_transformers provides C ++ / python API interfaces. we hope to do our best to adapt to a variety of online environments to reduce the difficulty of development for users.
+TurboTransformers provides C++ and python API interfaces. We do our best to adapt to a variety of online environments to reduce the difficulty of development for users.
The first step in using turbo is to load a pre-trained model. We provide a way to load pytorch and tensorflow pre-trained models in [huggingface/transformers](https://github.com/huggingface).
The specific conversion method is to use the corresponding script in ./tools to convert the pre-trained model into an npz format file, and turbo uses the C ++ or python interface to load the npz format model.
In particular, we consider that most of the pre-trained models are in pytorch format and used with python. We provide a shortcut for calling directly in python for the pytorch saved model.
-
+
-#### python APIs
+#### BERT Examples
+##### python APIs
Refer to examples in [./example/python](./example/python "python").
Since the user of BERT acceleration always requires a customized post-processing process for the task, we provide an example of how to write a sequence classification application.
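+Below is a minimal, hypothetical sketch of that pytorch shortcut (the model name, the `from_torch` conversion entry point and the call convention are assumptions based on this repo's docs; see [./example/python](./example/python "python") for the authoritative usage):
+```
+import torch
+import transformers
+import turbo_transformers
+
+# load a pretrained huggingface pytorch BERT and wrap it with turbo
+torch_model = transformers.BertModel.from_pretrained("bert-base-chinese")
+torch_model.eval()
+turbo_model = turbo_transformers.BertModel.from_torch(torch_model)
+
+# run inference with the accelerated model (interface assumed to mirror pytorch)
+input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long)
+with torch.no_grad():
+    output = turbo_model(input_ids)
+```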
-#### C++ APIs
+##### C++ APIs
Refer to [./example/cpp](./example/cpp "C ++") for an example.
Our example provides the GPU and two CPU multi-thread calling methods. One is to do one BERT inference using multiple threads; the other is to do multiple BERT inference, each of which using one thread.
Users can link turbo-transformers to your code through add_subdirectory.
-## Performance
-### CPU
-We tested the performance of TurboTransformers on three CPU hardware platforms.
-We choose [pytorch](https://github.com/huggingface "pytorch"), [pytorch-jit](https://pytorch.org/docs/stable/_modules/torch/jit.html "pytorch-jit" ) and [onnxruntime-mkldnn](https://github.com/microsoft/onnxruntime "onnxruntime-mkldnn") and TensorRT implementation as a comparison. The performance test result is the average of 150 iterations. In order to avoid the phenomenon that the data of the last iteration is cached in the cache during multiple tests, each test uses random data and refreshes the cache data after calculation.
-* Intel Xeon 61xx
-
-
-
-
-* Intel Xeon 6133
-Compared to the 61xx model, Intel Xeon 6133 has a longer vectorized length of 512 bits, and it has a 30 MB shared L3 cache between cores.
-
-
-
-
-### GPU
-We tested the performance of turbo_transformers on four GPU hardware platforms.
-We choose [pytorch](https://github.com/huggingface "pytorch"), [NVIDIA Faster Transformers](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer "FasterTransformer"), [onnxruntime-gpu](https://github.com/microsoft/onnxruntime "onnxrt-gpu") and [TensorRT](https://github.com/NVIDIA/TensorRT/tree/release/6.0/demo/BERT) implementation as a comparison. The performance test result is the average of 150 iterations.
-
-* RTX 2060
-
-
+#### Decoder Examples
+[TurboNLP/Translate-Demo](https://github.com/TurboNLP/Translate-Demo "translate") shows a demo of applying TurboTransformers to a translation task.
-* Tesla V100
-
-
-
-
-* Tesla P40
-
-
-
+## Performance
+[BERT Benchmark Results](./docs/bert.md)
-* Tesla M40
+[Transformer Decoder Results](./docs/decoder.md)
-
-
+[How to find the hotspots in your code](./docs/profiler.md)
## TODO
-Currently (April 2020), we only support a interface of the BERT encoder model using FP32. In the near futuer, we will add support for other models (GPT2, decoders, etc.) and low-precision floating point (CPU int8, GPU FP16).
+Currently (June 2020), we are working to add support for other models (Albert [Work In Progress], GPT2) and low-precision floating point (CPU int8, GPU FP16) in the near future.
+**Looking forward to your contributions!**
## Lisence
BSD 3-Clause License
@@ -162,11 +138,10 @@ BSD 3-Clause License
1. The results of Turbo Transformers may be different from the results of PyTorch after 2 digits behind the decimal point.
The diff mainly comes from Bert Output Layer. We use a approximate GELU algorithm, which may be different from PyTorch.
-2. On AuthenticAMD CPU, member function `from_torch` of class `BertModelWithPooler` and `BertModel` does not support PyTorch version as 1.5.0.
-In our opinion, the tensor transpose API of PyTorch is not stable. We use the following way to transpose weight matrices.
-```
-weight = torch.clone(torch.t(pooler_params['dense.weight']))
-```
+### History
+1. April 2020 v0.0.1, TurboTransformers was released, achieving state-of-the-art BERT inference speed on CPU/GPU.
+2. June 2020 v0.2.1, TurboTransformers added BLIS as a BLAS option, bringing better performance on AMD CPUs.
+3. June 2020 v0.3.0, TurboTransformers added support for the Transformer Decoder on CPU/GPU.
## Contact us
Although we recommand you post your problem with github issues, you can also join in our Turbo user group.
diff --git a/benchmark/run_gpu_benchmark.sh b/benchmark/run_gpu_benchmark.sh
index 0ec0a3ae..44c5ad81 100644
--- a/benchmark/run_gpu_benchmark.sh
+++ b/benchmark/run_gpu_benchmark.sh
@@ -18,6 +18,7 @@ FRAMEWORKS=("turbo-transformers" "torch")
# FRAMEWORKS=("onnxruntime")
SEQ_LEN=(10 20 40 60 80 100 200 300 400 500)
BATCH_SIZE=(1 20)
+
N=150
MODEL="bert-base-chinese"
for batch_size in ${BATCH_SIZE[*]}
diff --git a/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp b/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp
index 4292faf1..3794e2d6 100644
--- a/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp
+++ b/benchmark/turbo_transformers/layers/kernels/matmul_benchmark.cpp
@@ -27,7 +27,7 @@ using layers::kernels::common::FillRandom;
static void MatmulBenchmarkHelper(DLDeviceType device_type, bool trans_weight,
std::initializer_list<int64_t> weight_shape,
std::vector<int64_t> m_list) {
- constexpr int n_step = 1000;
+ constexpr int n_step = 100;
const std::string device_name = device_type == kDLCPU ? "CPU" : "GPU";
const std::string trans_name = trans_weight ? "Tran" : "NoTrans";
@@ -61,27 +61,15 @@ static void MatmulBenchmarkHelper(DLDeviceType device_type, bool trans_weight,
ss << device_name << " " << trans_name << " MatMul " << m << ", " << k
<< ", " << n << " ";
auto g_flops = m * n * k * 2 / 1e9;
-
- if (device_type == kDLGPU) {
-#ifdef TT_WITH_CUDA
- auto flops = benchmark::TestFuncSpeed(
- [&]() {
- layers::kernels::MatMul(input_tensor, false, weight_tensor,
- trans_weight, 1.0, &output_tensor, 0.0);
- },
- n_step, ss.str(), g_flops, device_type);
-
- std::cout << ss.str() << " flops: " << flops << std::endl;
-#endif
- } else {
- benchmark::TestFuncSpeed(
- [&]() {
- layers::kernels::MatMul(input_tensor, false, weight_tensor,
- trans_weight, 1.0, &output_tensor, 0.0);
- },
- n_step, ss.str(), g_flops, device_type);
- }
- }
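+ // benchmark::TestFuncSpeed handles both CPU and GPU timing, so the
+ // separate per-device branches are no longer needed.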
+ auto flops = benchmark::TestFuncSpeed(
+ [&]() {
+ layers::kernels::MatMul(input_tensor, false, weight_tensor,
+ trans_weight, 1.0, &output_tensor, 0.0);
+ },
+ n_step, ss.str(), g_flops, device_type);
+
+ std::cout << ss.str() << " flops: " << flops << std::endl;
+ } // for
}
TEST_CASE("matmal-cpu-benchmark") {
@@ -93,6 +81,62 @@ TEST_CASE("matmal-cpu-benchmark") {
std::cout << std::endl;
}
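+// Benchmarks square m x m MatMuls for every m in dim_list; g_flops below is the
+// per-call work in GFLOPs and benchmark::TestFuncSpeed reports the measured speed.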
+static void MatmulBenchmarkGeneralHelper(DLDeviceType device_type,
+ bool trans_weight,
+ std::vector<int64_t> dim_list) {
+ constexpr int n_step = 1000;
+ const std::string device_name = device_type == kDLCPU ? "CPU" : "GPU";
+ const std::string trans_name = trans_weight ? "Trans" : "NoTrans";
+
+ for (auto m : dim_list) {
+ std::initializer_list<int64_t> input_shape{m, m};
+ std::initializer_list<int64_t> weight_shape{m, m};
+ std::initializer_list<int64_t> output_shape{m, m};
+
+ using turbo_transformers::core::NewDLPackTensorT;
+
+ core::Tensor input_tensor(
+ NewDLPackTensorT<float>(input_shape, device_type, 0));
+ FillRandom<float>(input_tensor);
+
+ core::Tensor weight_tensor(
+ NewDLPackTensorT<float>(weight_shape, device_type, 0));
+ FillRandom<float>(weight_tensor);
+
+ core::Tensor output_tensor(
+ NewDLPackTensorT<float>(output_shape, device_type, 0));
+ FillRandom<float>(output_tensor);
+
+ std::stringstream ss;
+ ss << device_name << " " << trans_name << " MatMul " << m << ", " << m
+ << ", " << m << " ";
+ auto g_flops = m * m * m * 2 / 1e9;
+ auto flops = benchmark::TestFuncSpeed(
+ [&]() {
+ layers::kernels::MatMul(input_tensor, false, weight_tensor,
+ trans_weight, 1.0, &output_tensor, 0.0);
+ },
+ n_step, ss.str(), g_flops, device_type);
+
+ std::cout << ss.str() << " flops: " << flops << std::endl;
+ } // for
+}
+
+TEST_CASE("matmal-cpu-benchmark-general") {
+#if defined(TT_BLAS_USE_MKL)
+ std::cout << "blas uses MKL" << std::endl;
+#elif defined(TT_BLAS_USE_OPENBLAS)
+ std::cout << "blas uses OpenBLAS" << std::endl;
+#elif defined(TT_BLAS_USE_BLIS)
+ std::cout << "blas uses BLIS" << std::endl;
+#endif
+ std::cout << "=================================" << std::endl;
+ std::cout << "CPU General MatMul Benchmark" << std::endl;
+ std::vector<int64_t> dim_list{10, 50, 100, 500, 1000, 1500, 2000};
+ MatmulBenchmarkGeneralHelper(kDLCPU, false, dim_list);
+ std::cout << std::endl;
+}
+
#ifdef TT_WITH_CUDA
TEST_CASE("matmal-gpu-gemm7-benchmark") {
diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake
index c21cb09b..8e74829a 100644
--- a/cmake/cuda.cmake
+++ b/cmake/cuda.cmake
@@ -25,4 +25,4 @@ foreach(X ${ARCH_FLAGS})
endforeach()
message(STATUS "Generating CUDA code for ${CUDA_VERSION} SMs: ${CUDA_FLAGS}")
-set(CMAKE_CUDA_FLAGS "${CUDA_FLAGS} -Xcompiler -Wall -std=c++11 --expt-relaxed-constexpr --use_fast_math --expt-extended-lambda")
+set(CMAKE_CUDA_FLAGS "${CUDA_FLAGS} -Xcompiler -Wall --expt-relaxed-constexpr --use_fast_math --expt-extended-lambda")
diff --git a/docs/bert.md b/docs/bert.md
new file mode 100644
index 00000000..12c24c0b
--- /dev/null
+++ b/docs/bert.md
@@ -0,0 +1,38 @@
+We show BERT inference performance here.
+
+### CPU
+We tested the performance of TurboTransformers on three CPU hardware platforms.
+We chose [pytorch](https://github.com/huggingface "pytorch"), [pytorch-jit](https://pytorch.org/docs/stable/_modules/torch/jit.html "pytorch-jit"), [onnxruntime-mkldnn](https://github.com/microsoft/onnxruntime "onnxruntime-mkldnn") and TensorRT implementations for comparison. Each performance result is the average of 150 iterations. To avoid measurements being skewed by data still cached from a previous iteration, each test uses random input data and flushes the cache after the computation.
+* Intel Xeon 61xx
+
+
+
+
+* Intel Xeon 6133
+Compared to the 61xx models, the Intel Xeon 6133 has a wider vector unit (512 bits) and a 30 MB L3 cache shared between cores.
+
+
+
+
+### GPU
+We tested the performance of TurboTransformers on four GPU hardware platforms.
+We chose [pytorch](https://github.com/huggingface "pytorch"), [NVIDIA Faster Transformers](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer "FasterTransformer"), [onnxruntime-gpu](https://github.com/microsoft/onnxruntime "onnxrt-gpu") and [TensorRT](https://github.com/NVIDIA/TensorRT/tree/release/6.0/demo/BERT) implementations for comparison. Each performance result is the average of 150 iterations.
+
+* RTX 2060
+
+
+
+* Tesla V100
+
+
+
+
+* Tesla P40
+
+
+
+
+* Tesla M40
+
+
+
diff --git a/docs/decoder.md b/docs/decoder.md
new file mode 100644
index 00000000..a681dfa9
--- /dev/null
+++ b/docs/decoder.md
@@ -0,0 +1,6 @@
+We show Transformer Decoder inference performance here.
+
+For the translation demo [TurboNLP/Translate-Demo](https://github.com/TurboNLP/Translate-Demo "translate"),
+Turbo brings a 15.9% performance improvement on an RTX 2060 GPU.
+
+We are still working on decoder model optimization.
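+
+For reference, the turbo decoder is driven from python in the same way as the snippet in [docs/profiler.md](./profiler.md); a hypothetical call site (names taken from that snippet, the surrounding OpenNMT-py translator setup is assumed) looks like:
+```
+# inside an OpenNMT-py style translator, after replacing the pytorch decoder
+# with self.turbo_decoder (see the Translate-Demo repository for the full setup)
+dec_out, dec_attn = self.turbo_decoder(
+    decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
+)
+```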
diff --git a/docs/profiler.md b/docs/profiler.md
new file mode 100644
index 00000000..a0ab39f8
--- /dev/null
+++ b/docs/profiler.md
@@ -0,0 +1,68 @@
+## How to profile your code
+1. Set the option WITH_PROFILER to ON in CMakeLists.txt and recompile the code
+
+```
+option(WITH_PROFILER "Compile with profiler" ON)
+```
+2. Add a profiling context in your code, for example:
+
+```
+with turbo_transformers.pref_guard("info") as perf:
+ dec_out, dec_attn = self.turbo_decoder(
+ decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
+ )
+```
+
+3. The profiling results will be printed to your screen, like this:
+```
+info Time line:
+context/values/AddBiasTransposeForScore/reshape , 0.023328, 0.00835687 %
+context/keys/AddBiasTransposeForScore/Reshape , 0.030464, 0.0109132 %
+context/gemm2/k_out1/Reshape , 0.040384, 0.0144669 %
+context/keys/AddBiasTransposeForScore , 0.04688, 0.016794 %
+context/gemm1/v_out1/Reshape , 0.049568, 0.0177569 %
+context/values/AddBiasTransposeForScore , 0.050304, 0.0180206 %
+context/gemm2 , 0.300832, 0.107768 %
+context/gemm1 , 0.322112, 0.115391 %
+context/AddBiasTransposeForScore/q_out2/Reshape , 0.515776, 0.184768 %
+context/AddBiasTransposeForScore/q_out1/Reshape , 0.616032, 0.220683 %
+context/gemm0/q_out2/Reshape , 0.773504, 0.277095 %
+self/self_value/Reshape , 0.794784, 0.284718 %
+self/layernorm/Reshape , 0.801984, 0.287297 %
+FFN/Reshape , 0.851904, 0.30518 %
+self/self_key/Reshape , 0.986688, 0.353464 %
+FFN/AddBiasAct , 1.28634, 0.460808 %
+context/gemm0/q_out1/Reshape , 1.35478, 0.485329 %
+self/qkv_out1/Reshape , 1.5048, 0.539069 %
+context/AddBiasTransposeForScore , 1.5057, 0.53939 %
+self/SplitAddBiasTransposeForScore , 1.56646, 0.561159 %
+FFN/LayerNorm , 1.64701, 0.590013 %
+gemm5/Reshape , 1.65885, 0.594255 %
+context/gemm0/prelayernorm , 1.66512, 0.596501 %
+LayerNorm , 1.6856, 0.603838 %
+self/self_value/Copy , 1.68688, 0.604297 %
+batch_gemm4/Reshape , 1.69667, 0.607804 %
+Concat/Reshape , 1.796, 0.643387 %
+ApplyMaskAndSoftmax/Reshape , 1.80499, 0.646608 %
+batch_gemm3/Reshape , 2.03645, 0.729523 %
+Reshape , 2.1289, 0.762641 %
+self/layernorm/Copy , 2.53923, 0.909637 %
+self/qkv_out2/Reshape , 2.65715, 0.95188 %
+Concat , 2.76022, 0.988804 %
+context/gemm0/prelayernorm/Copy , 2.83021, 1.01387 %
+batch_gemm4 , 3.00442, 1.07628 %
+self/self_key/Copy , 3.07203, 1.1005 %
+batch_gemm3 , 3.34592, 1.19862 %
+ApplyMaskAndSoftmax , 3.67014, 1.31477 %
+TransposeForScore , 3.76816, 1.34988 %
+FFN/AddInputBias , 3.97325, 1.42335 %
+self/keys/Concat , 4.66528, 1.67126 %
+self/values/Concat , 5.08947, 1.82322 %
+context/gemm0 , 5.76464, 2.06509 %
+self/gemm012_fused , 7.82285, 2.8024 %
+gemm5 , 11.295, 4.04626 %
+FFN/gemm0 , 12.2295, 4.381 %
+FFN/gemm1 , 17.4551, 6.25299 %
+MultiHeadedAttention_context , 60.8736, 21.8069 %
+MultiHeadedAttention_self , 91.1025, 32.6359 %
+```
diff --git a/example/cpp/bert_model_example.cpp b/example/cpp/bert_model_example.cpp
index 335094fd..f6d3010a 100644
--- a/example/cpp/bert_model_example.cpp
+++ b/example/cpp/bert_model_example.cpp
@@ -11,8 +11,6 @@
// permissions and limitations under the License.
// See the AUTHORS file for names of contributors.
-#include "bert_model.h"
-
#include
#include
#include
@@ -20,6 +18,7 @@
#include
#include
+#include "bert_model.h"
#include "turbo_transformers/core/config.h"
static bool test(const std::string &model_path, bool use_cuda = false) {
@@ -54,9 +53,13 @@ static std::vector CallBackFunction(
}
bool test_multiple_threads(const std::string &model_path, bool only_input,
- int n_threads) {
+ bool use_cuda, int n_threads) {
std::shared_ptr<BertModel> model_ptr =
std::make_shared<BertModel>(model_path, DLDeviceType::kDLCPU, 12, 12);
+ // input_ids, position_ids and segment_ids rows may have different lengths.
+ // For example: std::vector<std::vector<int64_t>> input_ids{{1, 2, 3, 4, 5, 6, 7},
+ //                                                          {1, 2}};
std::vector<std::vector<int64_t>> input_ids{{12166, 10699, 16752, 4454},
{5342, 16471, 817, 16022}};
std::vector<std::vector<int64_t>> position_ids{{1, 0, 0, 0}, {1, 1, 1, 0}};
@@ -94,6 +97,10 @@ bool test_multiple_threads(const std::string &model_path, bool only_input,
// bert-base-uncased (2020.04.23 version), you may need to change it to
// real-time values.
if (only_input) {
+ std::cerr << vec.data()[0] << std::endl;
+ std::cerr << vec.data()[1] << std::endl;
+ std::cerr << vec.data()[768] << std::endl;
+ std::cerr << vec.data()[768 + 1] << std::endl;
assert(fabs(vec.data()[0] - -0.1901) < 1e-3);
assert(fabs(vec.data()[1] - 0.0193) < 1e-3);
assert(fabs(vec.data()[768] - 0.3060) < 1e-3);
@@ -132,8 +139,10 @@ int main(int argc, char *argv[]) {
turbo_transformers::core::SetNumThreads(1);
if (core::IsCompiledWithCUDA()) {
std::cout << "10 threads do 10 independent bert inferences." << std::endl;
- test_multiple_threads(model_path, true /*use cuda*/, 10);
+ test_multiple_threads(model_path, false /*only_input*/, true /*use cuda*/,
+ 10);
}
- test_multiple_threads(model_path, false /*not use cuda*/, 10);
+ test_multiple_threads(model_path, false /*only_input*/,
+ false /*not use cuda*/, 10);
return 0;
}
diff --git a/tools/convert_huggingface_bert_pytorch_to_npz.py b/tools/convert_huggingface_bert_pytorch_to_npz.py
index 87ae8889..a1ff2c7f 100644
--- a/tools/convert_huggingface_bert_pytorch_to_npz.py
+++ b/tools/convert_huggingface_bert_pytorch_to_npz.py
@@ -50,7 +50,7 @@ def main():
arrays[k],
arrays[k[:-len(q_weight_key)] + k_weight_key],
arrays[k[:-len(q_weight_key)] + v_weight_key]
- ], 0)))
+ ], 0).contiguous()).contiguous())
numpy_dict[k[:-len(q_weight_key)] + "qkv.weight"] = v.numpy()
elif k.endswith(q_bias_key):
v = torch.cat([
@@ -65,7 +65,8 @@ def main():
or k.endswith("pooler.dense.weight")
or (k.endswith("output.dense.weight")
or k.endswith("intermediate.dense.weight"))):
- numpy_dict[k] = torch.clone(torch.t(arrays[k])).numpy()
+ numpy_dict[k] = torch.clone(torch.t(
+ arrays[k]).contiguous()).numpy()
else:
numpy_dict[k] = arrays[k].numpy()
del arrays
diff --git a/tools/docker/Dockerfile_dev.cpu b/tools/docker/Dockerfile_dev.cpu
index fd84c3bd..9c91781e 100644
--- a/tools/docker/Dockerfile_dev.cpu
+++ b/tools/docker/Dockerfile_dev.cpu
@@ -1,6 +1,6 @@
# Build ONNX Runtime for benchmark
FROM continuumio/miniconda3 as OnnxRTBuilder
-RUN apt-get update && apt-get install -y build-essential
+RUN apt-get update && apt-get install -y build-essential gfortran
RUN /opt/conda/bin/conda install mklml cmake curl git numpy -c anaconda
RUN mkdir /src && cd /src/ && git clone -v https://github.com/microsoft/onnxruntime.git &&\
cd onnxruntime && git checkout v1.2.0 && git submodule update --init --recursive
@@ -22,6 +22,7 @@ COPY --from=PProfBuilder /go/bin/pprof /bin/pprof
# NOTE: 1. MKL is installed with pytorch.
# turbo-transformers will use the same MKL from PyTorch
RUN /opt/conda/bin/conda install pytorch==1.4.0 cpuonly -c pytorch && \
+ pip install OpenNMT-py==1.1.1 && \
/opt/conda/bin/conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \
/opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \
/opt/conda/bin/conda clean -afy
diff --git a/tools/docker/Dockerfile_dev.gpu b/tools/docker/Dockerfile_dev.gpu
index cfcb3014..8260f8dd 100644
--- a/tools/docker/Dockerfile_dev.gpu
+++ b/tools/docker/Dockerfile_dev.gpu
@@ -1,7 +1,7 @@
FROM IMAGE_BASE
RUN apt-get update && \
- apt-get install -y curl git wget bzip2 build-essential ninja-build g++ && rm -rf /var/lib/apt/lists/*
+ apt-get install -y curl git wget bzip2 build-essential gfortran ninja-build g++ && rm -rf /var/lib/apt/lists/*
ENV PATH=/opt/miniconda3/bin:${PATH} CONDA_PREFIX=/opt/miniconda3
RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
@@ -11,6 +11,7 @@ RUN curl -LO http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.s
conda install pytorch=PYTORCH_VERSION torchvision==0.3.0 cudatoolkit=CUDA_VERSION -c pytorch && \
conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \
conda install git git-lfs docopt -c conda-forge && \
+ pip install OpenNMT-py==1.1.1 && \
conda clean -afy
RUN pip --no-cache-dir install contexttimer future transformers docopt
diff --git a/tools/docker/Dockerfile_release.cpu b/tools/docker/Dockerfile_release.cpu
index 4c314009..fdca94c3 100644
--- a/tools/docker/Dockerfile_release.cpu
+++ b/tools/docker/Dockerfile_release.cpu
@@ -8,6 +8,7 @@ ENV PATH=/opt/conda/bin/bin:${PATH} CONDA_PREFIX=/opt/conda
# NOTE: 1. MKL is installed with pytorch.
# turbo-transformers will use the same MKL from PyTorch
RUN /opt/conda/bin/conda install pytorch==1.4.0 cpuonly -c pytorch && \
+ pip install OpenNMT-py==1.1.1 && \
/opt/conda/bin/conda install curl conda-verify conda-build mkl-include cmake -c anaconda && \
/opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \
/opt/conda/bin/conda clean -afy
diff --git a/turbo_transformers/core/CMakeLists.txt b/turbo_transformers/core/CMakeLists.txt
index 456cdc0f..df7a6659 100644
--- a/turbo_transformers/core/CMakeLists.txt
+++ b/turbo_transformers/core/CMakeLists.txt
@@ -22,6 +22,7 @@ add_library(tt_core
tensor.cpp
config.cpp
profiler.cpp
+ allocator.cpp
)
target_link_libraries(tt_core PUBLIC
absl::stacktrace
@@ -48,17 +49,15 @@ endif ()
if (WITH_GPU)
- target_sources(tt_core PRIVATE cuda_device_context.cpp cuda_allocator.cpp)
+ target_sources(tt_core PRIVATE cuda_device_context.cpp)
target_link_libraries(tt_core PUBLIC cudart cuda cublas)
endif()
-if (WITH_PROFILER)
- target_link_libraries(tt_core gperftools::profiler)
-endif ()
add_executable(tt_core_test
enforce_test.cpp
device_context_test.cpp
tensor_test.cpp
+ allocator_test.cpp
fp16_test.cpp)
target_link_libraries(tt_core_test catch2_test_main tt_core)
add_test(NAME tt_core_test COMMAND tt_core_test)
diff --git a/turbo_transformers/core/allocator.cpp b/turbo_transformers/core/allocator.cpp
new file mode 100644
index 00000000..7d6b071c
--- /dev/null
+++ b/turbo_transformers/core/allocator.cpp
@@ -0,0 +1,216 @@
+// Copyright (C) 2020 THL A29 Limited, a Tencent company.
+// All rights reserved.
+// Licensed under the BSD 3-Clause License (the "License"); you may
+// not use this file except in compliance with the License. You may
+// obtain a copy of the License at
+// https://opensource.org/licenses/BSD-3-Clause
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" basis,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+// implied. See the License for the specific language governing
+// permissions and limitations under the License.
+// See the AUTHORS file for names of contributors.
+
+#include "turbo_transformers/core/allocator.h"
+
+#include <map>
+#include <unordered_map>
+
+#ifdef TT_WITH_CUDA
+#include <cuda_runtime.h>
+
+#include <cub/util_allocator.cuh>
+#include <sstream>
+
+#include "turbo_transformers/core/cuda_device_context.h"
+#include "turbo_transformers/core/cuda_enforce.cuh"
+#endif
+
+namespace turbo_transformers {
+namespace core {
+
+struct BadAlloc : public std::exception {
+ explicit BadAlloc(std::string err_msg) : err_str_(err_msg) {}
+
+ const char *what() const noexcept override { return err_str_.c_str(); }
+
+ std::string err_str_;
+};
+
+#ifdef TT_WITH_CUDA
+static void *cuda_alloc(size_t sz) {
+ void *device_mem;
+ try {
+ TT_ENFORCE_CUDA_SUCCESS(cudaMalloc((void **)&(device_mem), sz));
+ } catch (...) {
+ throw BadAlloc("cudaMalloc failed.");
+ }
+ return device_mem;
+}
+
+static void cuda_free(void *data) { TT_ENFORCE_CUDA_SUCCESS(cudaFree(data)); }
+#endif
+
+namespace {
+void *allocate_impl(size_t size, DLDeviceType dev) {
+ if (kDLCPU == dev) {
+ return align_alloc(size);
+ } else if (kDLGPU == dev) {
+#ifdef TT_WITH_CUDA
+ auto addr = cuda_alloc(size);
+ return addr;
+#endif
+ } else {
+ TT_THROW("Not supported devtype");
+ }
+ return nullptr;
+}
+
+void free_impl(void *memory_addr, DLDeviceType dev) {
+ if (kDLCPU == dev) {
+ free(memory_addr);
+ } else if (kDLGPU == dev) {
+#ifdef TT_WITH_CUDA
+ cuda_free(memory_addr);
+#endif
+ } else {
+ TT_THROW("Not supported devtype");
+ }
+}
+
+} // namespace
+
+struct Allocator::BestFitAllocatorImpl {
+ public:
+ void free_cache(size_t size, DLDeviceType dev) {
+ if (size == 0) return;
+ size_t cur = 0;
+ while (!allocations_.empty()) { // free the largest
+ auto it = --allocations_.end();
+ cur += it->first;
+ free_impl(it->second, dev);
+ addr_size_map_.erase(it->second);
+ allocation_size_ -= it->first;
+ allocations_.erase(it);
+ if (cur >= size) return;
+ }
+ }
+
+ void *alloc(size_t size, DLDeviceType dev) {
+ auto it = allocations_.lower_bound(size);
+ void *allocated_addr;
+ if (it != allocations_.end() && it->first >= size) {
+ allocated_addr = it->second;
+ allocations_.erase(it);
+ } else {
+ try {
+ allocated_addr = allocate_impl(size, dev);
+ } catch (BadAlloc &) {
+ free_cache(size, dev);
+ allocated_addr = allocate_impl(size, dev);
+ }
+ }
+
+ addr_size_map_[allocated_addr] = size;
+ return allocated_addr;
+ }
+
+ void free(void *data, DLDeviceType dev) {
+ auto size = addr_size_map_[data];
+ allocations_.emplace(size, data);
+ allocation_size_ += size;
+ addr_size_map_.erase(data);
+ }
+
+ BestFitAllocatorImpl() : allocation_size_(0) {}
+
+ private:
+ std::multimap<size_t, void *> allocations_;
+ std::unordered_map<void *, size_t> addr_size_map_;
+ size_t allocation_size_;
+}; // struct Allocator::BestFitAllocatorImpl
+
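+// CachingAllocatorImpl forwards GPU requests to cub::CachingDeviceAllocator on
+// the current CUDA stream. If an allocation fails, all cached blocks are
+// released and the allocation is retried once; CPU requests simply use the
+// plain aligned allocation path.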
+struct Allocator::CachingAllocatorImpl {
+ void *alloc(size_t size, DLDeviceType dev) {
+ void *data = nullptr;
+ if (dev == kDLCPU) {
+ return allocate_impl(size, kDLCPU);
+ } else if (dev == kDLGPU) {
+#ifdef TT_WITH_CUDA
+ static auto stream = core::CUDADeviceContext::GetInstance().stream();
+ try {
+ cudaError_t result = cub_allocator.DeviceAllocate(&data, size, stream);
+ if (result != cudaSuccess) {
+ throw BadAlloc("DeviceAllocate failed.");
+ }
+ } catch (...) {
+ cub_allocator.FreeAllCached();
+ cudaError_t result = cub_allocator.DeviceAllocate(&data, size, stream);
+ if (result != cudaSuccess) {
+ std::stringstream ss;
+ ss << "DeviceAllocate failed Again. " << size;
+ throw BadAlloc(ss.str());
+ }
+ }
+#endif
+ }
+ return data;
+
+ } // alloc
+
+ void free(void *data, DLDeviceType dev) {
+ if (dev == kDLCPU) {
+ free_impl(data, kDLCPU);
+ } else if (dev == kDLGPU) {
+#ifdef TT_WITH_CUDA
+ try {
+ cudaError_t result = cub_allocator.DeviceFree(data);
+ if (result != cudaErrorCudartUnloading && result != cudaSuccess) {
+ throw std::runtime_error("DeviceFree failed ");
+ }
+ } catch (...) {
+ }
+#endif
+ }
+ }
+
+ private:
+#ifdef TT_WITH_CUDA
+ cub::CachingDeviceAllocator cub_allocator;
+#endif
+}; // struct Allocator::CachingAllocatorImpl
+
+/*********
+ * APIs of Allocator
+ *********/
+Allocator::Allocator()
+ : bestfit_allocator_(new BestFitAllocatorImpl()),
+ caching_allocator_(new CachingAllocatorImpl()) {}
+Allocator::~Allocator() = default;
+
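+// Strategy selection: CPU memory always uses the plain aligned allocator; for
+// GPU memory the caller picks a strategy by name, "bestfit" for
+// BestFitAllocatorImpl or "cub" for CachingAllocatorImpl. Any other strategy
+// name falls through and returns nullptr (allocate) or does nothing (free).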
+void *Allocator::allocate(size_t size, const std::string &strategy,
+ DLDeviceType dev) {
+ if (dev == kDLCPU) {
+ return allocate_impl(size, dev);
+ }
+ if ("bestfit" == strategy) {
+ return bestfit_allocator_->alloc(size, dev);
+ } else if ("cub" == strategy) {
+ return caching_allocator_->alloc(size, dev);
+ }
+ return nullptr;
+}
+
+void Allocator::free(void *memory, const std::string &strategy,
+ DLDeviceType dev) {
+ if (dev == kDLCPU) {
+ return free_impl(memory, dev);
+ }
+ if ("bestfit" == strategy) {
+ bestfit_allocator_->free(memory, dev);
+ } else if ("cub" == strategy) {
+ caching_allocator_->free(memory, dev);
+ }
+}
+
+} // namespace core
+} // namespace turbo_transformers
diff --git a/turbo_transformers/core/cuda_allocator.h b/turbo_transformers/core/allocator.h
similarity index 62%
rename from turbo_transformers/core/cuda_allocator.h
rename to turbo_transformers/core/allocator.h
index 36df1247..23787e78 100644
--- a/turbo_transformers/core/cuda_allocator.h
+++ b/turbo_transformers/core/allocator.h
@@ -16,32 +16,34 @@
#include