From 68b8f723bdf8fe255f00ce4fc5b78c33aaec4ff4 Mon Sep 17 00:00:00 2001 From: feifeibear Date: Tue, 24 Nov 2020 21:49:20 +0800 Subject: [PATCH] Develop (#195) Add distilBERT --- CMakeLists.txt | 3 +- benchmark/benchmark.py | 3 - benchmark/onnx_benchmark_helper.py | 4 +- benchmark/torch_benchmark_helper.py | 7 +- benchmark/turbo_benchmark_helper.py | 10 +- distrill/bert_model.txt | 0 distrill/distill_bert.txt | 0 distrill/distrill_bert.py | 45 +++ requirements.txt | 2 +- tools/docker/Dockerfile_release.cpu | 2 +- tools/docker/Dockerfile_release.gpu | 4 +- .../layers/positionwise_ffn.cpp | 48 +++ turbo_transformers/layers/positionwise_ffn.h | 30 ++ turbo_transformers/python/pybind.cpp | 12 + .../tests/distill_bert_attention_test.py | 117 +++++++ .../python/tests/distill_bert_ffn_test.py | 110 ++++++ .../python/tests/distill_bert_model_test.py | 111 +++++++ .../tests/distill_transformer_block_test.py | 109 ++++++ .../python/tests/distill_transformer_test.py | 112 +++++++ .../turbo_transformers/layers/__init__.py | 5 +- .../layers/modeling_distillbert.py | 314 ++++++++++++++++++ 21 files changed, 1031 insertions(+), 17 deletions(-) create mode 100644 distrill/bert_model.txt create mode 100644 distrill/distill_bert.txt create mode 100644 distrill/distrill_bert.py create mode 100644 turbo_transformers/python/tests/distill_bert_attention_test.py create mode 100644 turbo_transformers/python/tests/distill_bert_ffn_test.py create mode 100644 turbo_transformers/python/tests/distill_bert_model_test.py create mode 100644 turbo_transformers/python/tests/distill_transformer_block_test.py create mode 100644 turbo_transformers/python/tests/distill_transformer_test.py create mode 100644 turbo_transformers/python/turbo_transformers/layers/modeling_distillbert.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 961b3284..b8231e3d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,8 @@ set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_FLAGS "-Wall") set(CMAKE_C_FLAGS "-Wall") -set(TURBO_TRANSFORMERS_VERSION 0.5.0) + +set(TURBO_TRANSFORMERS_VERSION 0.5.1) option(WITH_PROFILER "Compile with profiler" OFF) diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 780769d4..3227e0cc 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -30,9 +30,6 @@ --enable_mem_opt Use model aware memory optimization for BERT. """ -import json -import os - import docopt from turbo_benchmark_helper import benchmark_turbo_transformers from torch_benchmark_helper import benchmark_torch diff --git a/benchmark/onnx_benchmark_helper.py b/benchmark/onnx_benchmark_helper.py index fff7afe4..6b4ac69f 100644 --- a/benchmark/onnx_benchmark_helper.py +++ b/benchmark/onnx_benchmark_helper.py @@ -26,7 +26,6 @@ def generate_onnx_model(model_name: str, use_dynamic_axes: bool = False): import transformers import torch - import os test_device = torch.device( 'cuda:0') if backend == "GPU" and use_gpu else torch.device('cpu:0') @@ -45,6 +44,9 @@ def generate_onnx_model(model_name: str, elif model_name == "roberta": cfg = transformers.RobertaConfig() model = transformers.RobertaModel(cfg) + elif model_name == "distilbert": + cfg = transformers.DistilBertConfig() + model = transformers.DistilBertModel(cfg) else: raise (f"benchmark does not support {model_name}") diff --git a/benchmark/torch_benchmark_helper.py b/benchmark/torch_benchmark_helper.py index 784d4269..5b960df3 100644 --- a/benchmark/torch_benchmark_helper.py +++ b/benchmark/torch_benchmark_helper.py @@ -22,10 +22,6 @@ def benchmark_torch(model_name: str, seq_len: int, batch_size: int, n: int, import benchmark_helper test_device = torch.device('cuda:0') if use_gpu else torch.device('cpu:0') - if use_gpu: - print("using GPU") - else: - print("using CPU") torch.set_grad_enabled(False) torch.set_num_threads(num_threads) @@ -39,6 +35,9 @@ def benchmark_torch(model_name: str, seq_len: int, batch_size: int, n: int, elif model_name == "roberta": cfg = transformers.RobertaConfig() model = transformers.RobertaModel(cfg) + elif model_name == "distilbert": + cfg = transformers.DistilBertConfig() + model = transformers.DistilBertModel(cfg) else: raise (f"benchmark does not support {model_name}") model.eval() diff --git a/benchmark/turbo_benchmark_helper.py b/benchmark/turbo_benchmark_helper.py index 8871bb5c..fbab9e38 100644 --- a/benchmark/turbo_benchmark_helper.py +++ b/benchmark/turbo_benchmark_helper.py @@ -24,10 +24,6 @@ def benchmark_turbo_transformers(model_name: str, seq_len: int, import turbo_transformers import benchmark_helper test_device = torch.device('cuda:0') if use_gpu else torch.device('cpu:0') - if use_gpu: - print("using GPU") - else: - print("using CPU") cfg = None torch.set_grad_enabled(False) if model_name == "bert": @@ -48,6 +44,12 @@ def benchmark_turbo_transformers(model_name: str, seq_len: int, model.to(test_device) model.eval() model = turbo_transformers.RobertaModel.from_torch(model) + elif model_name == "distilbert": + cfg = transformers.DistilBertConfig() + model = transformers.DistilBertModel(cfg) + model.to(test_device) + model.eval() + model = turbo_transformers.DistilBertModel.from_torch(model) else: raise (f"benchmark does not support {model_name}") diff --git a/distrill/bert_model.txt b/distrill/bert_model.txt new file mode 100644 index 00000000..e69de29b diff --git a/distrill/distill_bert.txt b/distrill/distill_bert.txt new file mode 100644 index 00000000..e69de29b diff --git a/distrill/distrill_bert.py b/distrill/distrill_bert.py new file mode 100644 index 00000000..a105c519 --- /dev/null +++ b/distrill/distrill_bert.py @@ -0,0 +1,45 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +from transformers import DistilBertTokenizer, DistilBertModel +from transformers import BertTokenizer, BertModel +import torch + +tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") +inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") +# inputs = torch.randint(low=0, +# high=cfg.vocab_size - 1, +# size=(1, 10), +# dtype=torch.long, +# device=torch.device("cpu:0")) + +## distrillation model +model = DistilBertModel.from_pretrained("distilbert-base-uncased", + return_dict=True) + +## bert model +bert_model = BertModel.from_pretrained("bert-base-uncased", return_dict=True) + +cfg = model.config +print(cfg) +print(inputs) +outputs = model(**inputs) +bert_outputs = bert_model(**inputs) + +print(model) +print(bert_model) + +# print(bert_outputs - outputs) +# +# last_hidden_states = outputs.last_hidden_state +# print(last_hidden_states) diff --git a/requirements.txt b/requirements.txt index d4b250f5..f020e066 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ contexttimer onnx future -transformers==3.0.2 +transformers==3.4.0 diff --git a/tools/docker/Dockerfile_release.cpu b/tools/docker/Dockerfile_release.cpu index 4b96e547..302b6be9 100644 --- a/tools/docker/Dockerfile_release.cpu +++ b/tools/docker/Dockerfile_release.cpu @@ -14,5 +14,5 @@ RUN /opt/conda/bin/conda install pytorch==1.5.0 cpuonly -c pytorch && \ /opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \ /opt/conda/bin/conda clean -afy -RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt onnxruntime-tools +RUN pip --no-cache-dir install contexttimer future transformers==3.4.0 docopt onnxruntime-tools WORKDIR /workspace diff --git a/tools/docker/Dockerfile_release.gpu b/tools/docker/Dockerfile_release.gpu index 82eb5a44..6ab7ea84 100644 --- a/tools/docker/Dockerfile_release.gpu +++ b/tools/docker/Dockerfile_release.gpu @@ -13,7 +13,9 @@ RUN curl -LO https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3- conda install pytorch=PYTORCH_VERSION cudatoolkit=CUDA_VERSION cudnn --freeze-installed -c pytorch && \ conda clean -yfa -RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt OpenNMT-py==1.1.0 onnxruntime-gpu==1.3.0 + +RUN pip --no-cache-dir install contexttimer future transformers==3.4.0 docopt OpenNMT-py==1.1.0 onnxruntime-gpu==1.3.0 + COPY --from=DEV_IMAGE /opt/miniconda3/lib/python3.7/site-packages/turbo_transformers /opt/miniconda3/lib/python3.7/site-packages/turbo_transformers diff --git a/turbo_transformers/layers/positionwise_ffn.cpp b/turbo_transformers/layers/positionwise_ffn.cpp index fbec26bf..09398ac4 100644 --- a/turbo_transformers/layers/positionwise_ffn.cpp +++ b/turbo_transformers/layers/positionwise_ffn.cpp @@ -76,5 +76,53 @@ void PositionwiseFeedForward::operator()(const core::Tensor& input_tensor, void PositionwiseFeedForward::EnforceShapeAndType() const {} +void DistrillFFN::operator()(const core::Tensor& input_tensor, + core::Tensor* output_tensor, + bool is_trans_weight) const { + auto d_ff = + is_trans_weight ? dense_weight_1_.shape(0) : dense_weight_1_.shape(1); + + auto model_dim_weight = + is_trans_weight ? dense_weight_1_.shape(1) : dense_weight_1_.shape(0); + auto model_dim = input_tensor.shape(2); + + TT_ENFORCE_EQ( + model_dim_weight, model_dim, + "dense weight and input tensor should have the same model_dim."); + + auto devType = input_tensor.device_type(); + auto devId = input_tensor.device_id(); + + // input tensor size (batch_size, input_len, model_dim) + auto batch_size = input_tensor.shape(0); + auto input_len = input_tensor.shape(1); + // allocate memory for temp data + core::Tensor input_tensor_copy(nullptr); + input_tensor_copy.Reshape({batch_size, input_len, model_dim}, devType, + devId); + core::Tensor temp_tensor(nullptr); + temp_tensor.Reshape({batch_size * input_len, d_ff}, devType, devId); + + // start computation + core::Copy(input_tensor, input_tensor_copy, "FFN/AddInputBias"); + + output_tensor->Reshape({batch_size, input_len, model_dim}, devType, + devId, "FFN/Reshape"); + kernels::MatMul(input_tensor_copy, false, dense_weight_1_, is_trans_weight, + 1.0, // input (b*seq, model) X dense_weight_1_ (model_dim, + // d_ff) -> temp_tensor (B*seq, d_ff) + &temp_tensor, 0.0, "FFN/gemm0"); + kernels::AddBiasAct( + dense_bias_1_, &temp_tensor, "FFN/AddBiasAct"); + kernels::MatMul(temp_tensor, false, dense_weight_2_, is_trans_weight, 1.0, + &input_tensor_copy, 0.0, "FFN/gemm1"); + kernels::AddInputBias(input_tensor, input_tensor_copy, dense_bias_2_, + output_tensor, "FFN/AddInputBias"); + kernels::LayerNorm(layer_norm_weight_, layer_norm_bias_, output_tensor, + 1e-12, "FFN/LayerNorm"); +} + +void DistrillFFN::EnforceShapeAndType() const {} + } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/positionwise_ffn.h b/turbo_transformers/layers/positionwise_ffn.h index 9e091872..a3ab691d 100644 --- a/turbo_transformers/layers/positionwise_ffn.h +++ b/turbo_transformers/layers/positionwise_ffn.h @@ -14,6 +14,7 @@ #pragma once #include #include + #include "turbo_transformers/core/tensor.h" namespace turbo_transformers { @@ -51,5 +52,34 @@ class PositionwiseFeedForward { core::Tensor layer_norm_bias_; }; +class DistrillFFN { + public: + DistrillFFN(core::Tensor dense_weight_1, core::Tensor dense_bias_1, + core::Tensor dense_weight_2, core::Tensor dense_bias_2, + core::Tensor layer_norm_weight, core::Tensor layer_norm_bias) + : dense_weight_1_(std::move(dense_weight_1)), + dense_bias_1_(std::move(dense_bias_1)), + dense_weight_2_(std::move(dense_weight_2)), + dense_bias_2_(std::move(dense_bias_2)), + layer_norm_weight_(std::move(layer_norm_weight)), + layer_norm_bias_(std::move(layer_norm_bias)) { + EnforceShapeAndType(); + } + void EnforceShapeAndType() const; + + // according to profiling results on Intel 61xx, is_trans_weight = true is + // faster + void operator()(const core::Tensor &input_tensor, core::Tensor *output, + bool is_trans_weight = true) const; + + private: + core::Tensor dense_weight_1_; + core::Tensor dense_bias_1_; + core::Tensor dense_weight_2_; + core::Tensor dense_bias_2_; + core::Tensor layer_norm_weight_; + core::Tensor layer_norm_bias_; +}; + } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/python/pybind.cpp b/turbo_transformers/python/pybind.cpp index c6253d5d..74e17cbb 100644 --- a/turbo_transformers/python/pybind.cpp +++ b/turbo_transformers/python/pybind.cpp @@ -221,6 +221,18 @@ PYBIND11_MODULE(turbo_transformers_cxx, m) { })) .def("__call__", &layers::PositionwiseFeedForward::operator()); + py::class_(m, "DistrillFFN") + .def(py::init([](core::Tensor &dense_weight_1, core::Tensor &dense_bias_1, + core::Tensor &dense_weight_2, core::Tensor &dense_bias_2, + core::Tensor &layer_norm_weight, + core::Tensor &layer_norm_bias) -> layers::DistrillFFN * { + return new layers::DistrillFFN( + std::move(dense_weight_1), std::move(dense_bias_1), + std::move(dense_weight_2), std::move(dense_bias_2), + std::move(layer_norm_weight), std::move(layer_norm_bias)); + })) + .def("__call__", &layers::DistrillFFN::operator()); + py::class_(m, "FusedAddBiasGELU") .def(py::init([](core::Tensor &dense_bias) -> layers::FusedAddBiasGELU * { return new layers::FusedAddBiasGELU(std::move(dense_bias)); diff --git a/turbo_transformers/python/tests/distill_bert_attention_test.py b/turbo_transformers/python/tests/distill_bert_attention_test.py new file mode 100644 index 00000000..eb19c5ad --- /dev/null +++ b/turbo_transformers/python/tests/distill_bert_attention_test.py @@ -0,0 +1,117 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +import turbo_transformers + +import unittest +import sys +import torch +from transformers.modeling_distilbert import DistilBertConfig +from transformers.modeling_distilbert import MultiHeadSelfAttention as DistilAttention +from torch import nn + +import os +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "tt_distrill_attention.txt" + + +def create_test(batch_size, seq_length): + class TestDistillBertAttention(unittest.TestCase): + def init_data(self, use_cuda): + test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + torch.set_grad_enabled(False) + self.cfg = DistilBertConfig(attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0) + self.cfg.output_attentions = True + self.torch_attention = DistilAttention(self.cfg) + self.torch_sa_layer_norm = nn.LayerNorm( + normalized_shape=self.cfg.dim, eps=1e-12) + self.torch_attention.eval() + self.torch_sa_layer_norm.eval() + if use_cuda: + self.torch_attention.to(test_device) + self.torch_sa_layer_norm.to(test_device) + + # Get FT Attention + self.turbo_attention = turbo_transformers.DistillBertAttention.from_torch( + self.torch_attention, self.torch_sa_layer_norm) + + hidden_size = self.cfg.hidden_size + self.input_tensor = torch.rand(size=(batch_size, seq_length, + hidden_size), + dtype=torch.float32, + device=test_device) + # NOTE, the mask of distilled attention is different from huggingface bert attention. + self.attention_mask = torch.ones((batch_size, seq_length), + dtype=torch.float32, + device=test_device) + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + self.init_data(use_cuda) + device = "GPU" if use_cuda else "CPU" + torch_model = lambda: self.torch_sa_layer_norm( + self.torch_attention(query=self.input_tensor, + key=self.input_tensor, + value=self.input_tensor, + mask=self.attention_mask, + output_attentions=False)[0] + self. + input_tensor) + torch_attention_result, torch_qps, torch_time_consume = \ + test_helper.run_model(torch_model, use_cuda, num_iter, use_profile=False) + print( + f"DistilAttention+LN \"({batch_size},{seq_length:03})\" ", + f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}") + + turbo_model = lambda: self.turbo_attention( + self.input_tensor, + self.attention_mask, + output_attentions=self.cfg.output_attentions)[0] + + turbo_attention_result, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_model, use_cuda, + num_iter) + print( + f"DistilAttention \"({batch_size},{seq_length:03})\" ", + f" {device} Turbo QPS, {turbo_qps}, time, {turbo_time_consume}" + ) + + self.assertTrue( + torch.max( + torch.abs(torch_attention_result - turbo_attention_result)) + < (1e-3 if use_cuda else 1e-4)) + + def test_distillbert_attention(self): + self.check_torch_and_turbo(use_cuda=False, num_iter=1) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True, num_iter=1) + + globals( + )[f"TestDistillBertAtt{batch_size}_{seq_length:3}"] = TestDistillBertAttention + + +with open(fname, "w") as fh: + fh.write(", torch, turbo_transformers\n") +for batch_size in [1, 2]: + for seq_length in [10, 20, 128]: + create_test(batch_size, seq_length) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/distill_bert_ffn_test.py b/turbo_transformers/python/tests/distill_bert_ffn_test.py new file mode 100644 index 00000000..63af780a --- /dev/null +++ b/turbo_transformers/python/tests/distill_bert_ffn_test.py @@ -0,0 +1,110 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +import turbo_transformers + +import unittest +import sys +import torch +import os + +from transformers.modeling_distilbert import DistilBertConfig +from transformers.modeling_distilbert import FFN as DistilFFN + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "distrill_ffn.txt" + + +def create_test(batch_size, input_len): + class TestDistillFFN(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + self.cfg = DistilBertConfig(attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0) + + torch.set_grad_enabled(False) + self.torch_ffn = DistilFFN(self.cfg) + self.torch_ffn.eval() + self.output_layer_norm = torch.nn.LayerNorm( + normalized_shape=self.cfg.dim, eps=1e-12) + if use_cuda: + self.torch_ffn.to(self.test_device) + self.output_layer_norm.to(self.test_device) + + self.turbo_ffn = turbo_transformers.DistrillFFN.from_torch( + self.torch_ffn, self.output_layer_norm) + # (batch_size, input_len, model_dim) + self.inputs = torch.rand(size=(batch_size, input_len, + self.cfg.dim), + dtype=torch.float32, + device=self.test_device) + + print(self.cfg.activation) + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + self.init_data(use_cuda) + device = "GPU" if use_cuda else "CPU" + + torch_model = lambda: self.output_layer_norm( + self.torch_ffn(self.inputs) + self.inputs) + torch_res, torch_qps, torch_time_consume = \ + test_helper.run_model(torch_model, use_cuda, num_iter) + + print( + f"DistrillFFN \"({batch_size}, {input_len:03})\" ", + f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}") + + turbo_res = lambda: self.turbo_ffn(self.inputs, + is_trans_weight=True) + with turbo_transformers.pref_guard("gpref_test") as perf: + turbo_res, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_res, use_cuda, num_iter) + + print( + f"DistrillFFN \"({batch_size}, {input_len:03})\" ", + f"{device} Turbo Trans QPS, {turbo_qps}, time, {turbo_time_consume}" + ) + + print(torch.max(torch.abs(torch_res - turbo_res))) + self.assertTrue(torch.max(torch.abs(torch_res - turbo_res)) < 1e-3) + + with open(fname, "a") as fh: + fh.write( + f"\"({batch_size},{input_len:03})\", {torch_qps}, {turbo_qps}\n" + ) + + def test_distrill_ffn(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals()[f"TestDistillFFN{batch_size}_{input_len:3}"] = TestDistillFFN + + +with open(fname, "w") as fh: + fh.write(", torch, turbo_trans\n") + +for batch_size in [1, 4]: + for input_len in [10, 20, 30, 40, 50]: + create_test(batch_size, input_len) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/distill_bert_model_test.py b/turbo_transformers/python/tests/distill_bert_model_test.py new file mode 100644 index 00000000..2993bb92 --- /dev/null +++ b/turbo_transformers/python/tests/distill_bert_model_test.py @@ -0,0 +1,111 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +import turbo_transformers + +import unittest +import sys +import torch +import os + +from transformers.modeling_distilbert import DistilBertConfig +from transformers.modeling_distilbert import DistilBertModel + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "distrill_tbert.txt" + + +def create_test(batch_size, input_len): + class TestDistillBertModel(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + self.cfg = DistilBertConfig(attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0) + + torch.set_grad_enabled(False) + self.torch_model = DistilBertModel(self.cfg) + self.torch_model.eval() + if use_cuda: + self.torch_model.to(self.test_device) + + self.turbo_transformer = turbo_transformers.DistilBertModel.from_torch( + self.torch_model) + # (batch_size, input_len, model_dim) + self.inputs = torch.randint(low=0, + high=self.cfg.vocab_size - 1, + size=(batch_size, input_len), + dtype=torch.long, + device=self.test_device) + self.attention_mask = torch.ones((batch_size, input_len), + dtype=torch.long, + device=self.test_device) + self.head_mask = [None] * self.cfg.num_hidden_layers + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + self.init_data(use_cuda) + device = "GPU" if use_cuda else "CPU" + + torch_model = lambda: self.torch_model(self.inputs, self. + attention_mask) + torch_res, torch_qps, torch_time_consume = \ + test_helper.run_model(torch_model, use_cuda, num_iter) + + print( + f"DistillBertModel \"({batch_size}, {input_len:03})\" ", + f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}") + + turbo_res = lambda: self.turbo_transformer( + self.inputs, self.attention_mask, head_mask=self.head_mask) + with turbo_transformers.pref_guard("gpref_test") as perf: + turbo_res, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_res, use_cuda, num_iter) + + print( + f"DistillBertModel \"({batch_size}, {input_len:03})\" ", + f"{device} Turbo QPS, {turbo_qps}, time, {turbo_time_consume}") + + self.assertTrue( + torch.max(torch.abs(torch_res[0] - turbo_res[0])) < 1e-2 + if use_cuda else 1e-3) + + with open(fname, "a") as fh: + fh.write( + f"\"({batch_size},{input_len:03})\", {torch_qps}, {turbo_qps}\n" + ) + + def test_distrill_bert_model(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals( + )[f"TestDistillTBertModel{batch_size}_{input_len:3}"] = TestDistillBertModel + + +with open(fname, "w") as fh: + fh.write(", torch, turbo_trans\n") + +for batch_size in [4]: + for input_len in [10]: + create_test(batch_size, input_len) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/distill_transformer_block_test.py b/turbo_transformers/python/tests/distill_transformer_block_test.py new file mode 100644 index 00000000..81cf3968 --- /dev/null +++ b/turbo_transformers/python/tests/distill_transformer_block_test.py @@ -0,0 +1,109 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +import turbo_transformers + +import unittest +import sys +import torch +import os + +from transformers.modeling_distilbert import DistilBertConfig +from transformers.modeling_distilbert import TransformerBlock as DistilTransformerBlock + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "distrill_transformer_block.txt" + + +def create_test(batch_size, input_len): + class TestDistillTransformerBlock(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + self.cfg = DistilBertConfig(attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0) + + torch.set_grad_enabled(False) + self.torch_transformer_block = DistilTransformerBlock(self.cfg) + self.torch_transformer_block.eval() + if use_cuda: + self.torch_transformer_block.to(self.test_device) + + self.turbo_transformer_block = turbo_transformers.DistrillTransformerBlock.from_torch( + self.torch_transformer_block) + # (batch_size, input_len, model_dim) + self.attention_mask = torch.ones((batch_size, input_len), + dtype=torch.float32, + device=self.test_device) + + self.inputs = torch.rand(size=(batch_size, input_len, + self.cfg.dim), + dtype=torch.float32, + device=self.test_device) + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + self.init_data(use_cuda) + device = "GPU" if use_cuda else "CPU" + + torch_model = lambda: self.torch_transformer_block( + self.inputs, self.attention_mask) + torch_res, torch_qps, torch_time_consume = \ + test_helper.run_model(torch_model, use_cuda, num_iter) + + print( + f"DistrillTransformerBlock \"({batch_size}, {input_len:03})\" ", + f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}") + + turbo_res = lambda: self.turbo_transformer_block( + self.inputs, self.attention_mask) + with turbo_transformers.pref_guard("gpref_test") as perf: + turbo_res, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_res, use_cuda, num_iter) + + print( + f"DistrillTransformerBlock \"({batch_size}, {input_len:03})\" ", + f"{device} Turbo QPS, {turbo_qps}, time, {turbo_time_consume}") + + self.assertTrue( + torch.max(torch.abs(torch_res[0] - turbo_res[0])) < 1e-3) + + with open(fname, "a") as fh: + fh.write( + f"\"({batch_size},{input_len:03})\", {torch_qps}, {turbo_qps}\n" + ) + + def test_distrill_transformer_block(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals( + )[f"TestDistillTransformerBlock{batch_size}_{input_len:3}"] = TestDistillTransformerBlock + + +with open(fname, "w") as fh: + fh.write(", torch, turbo_trans\n") + +for batch_size in [1, 4]: + for input_len in [10, 20, 30, 40, 50]: + create_test(batch_size, input_len) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/tests/distill_transformer_test.py b/turbo_transformers/python/tests/distill_transformer_test.py new file mode 100644 index 00000000..bb0919f4 --- /dev/null +++ b/turbo_transformers/python/tests/distill_transformer_test.py @@ -0,0 +1,112 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +import turbo_transformers + +import unittest +import sys +import torch +import os + +from transformers.modeling_distilbert import DistilBertConfig +from transformers.modeling_distilbert import Transformer as DistilTransformer + +sys.path.append(os.path.dirname(__file__)) +import test_helper + +fname = "distrill_transformer.txt" + + +def create_test(batch_size, input_len): + class TestDistillTransformer(unittest.TestCase): + def init_data(self, use_cuda): + self.test_device = torch.device('cuda:0') if use_cuda else \ + torch.device('cpu:0') + if not use_cuda: + torch.set_num_threads(4) + turbo_transformers.set_num_threads(4) + + self.cfg = DistilBertConfig(attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0) + + torch.set_grad_enabled(False) + self.torch_transformer = DistilTransformer(self.cfg) + self.torch_transformer.eval() + if use_cuda: + self.torch_transformer.to(self.test_device) + + self.turbo_transformer = turbo_transformers.DistrillTransformer.from_torch( + self.torch_transformer) + # (batch_size, input_len, model_dim) + self.attention_mask = torch.ones((batch_size, input_len), + dtype=torch.float32, + device=self.test_device) + + self.inputs = torch.rand(size=(batch_size, input_len, + self.cfg.dim), + dtype=torch.float32, + device=self.test_device) + self.head_mask = [None] * self.cfg.num_hidden_layers + + def check_torch_and_turbo(self, use_cuda, num_iter=1): + self.init_data(use_cuda) + device = "GPU" if use_cuda else "CPU" + + torch_model = lambda: self.torch_transformer( + self.inputs, self.attention_mask, head_mask=self.head_mask) + torch_res, torch_qps, torch_time_consume = \ + test_helper.run_model(torch_model, use_cuda, num_iter) + + print( + f"DistillBertTransformer \"({batch_size}, {input_len:03})\" ", + f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}") + + turbo_res = lambda: self.turbo_transformer( + self.inputs, self.attention_mask, head_mask=self.head_mask) + with turbo_transformers.pref_guard("gpref_test") as perf: + turbo_res, turbo_qps, turbo_time_consume = \ + test_helper.run_model(turbo_res, use_cuda, num_iter) + + print( + f"DistillBertTransformer \"({batch_size}, {input_len:03})\" ", + f"{device} Turbo QPS, {turbo_qps}, time, {turbo_time_consume}") + + print(torch.max(torch.abs(torch_res[0] - turbo_res[0]))) + self.assertTrue( + torch.max(torch.abs(torch_res[0] - turbo_res[0])) < 1e-2 + if use_cuda else 1e-3) + + with open(fname, "a") as fh: + fh.write( + f"\"({batch_size},{input_len:03})\", {torch_qps}, {turbo_qps}\n" + ) + + def test_distrill_transformer(self): + self.check_torch_and_turbo(use_cuda=False) + if torch.cuda.is_available() and \ + turbo_transformers.config.is_compiled_with_cuda(): + self.check_torch_and_turbo(use_cuda=True) + + globals( + )[f"TestDistillTransformer{batch_size}_{input_len:3}"] = TestDistillTransformer + + +with open(fname, "w") as fh: + fh.write(", torch, turbo_trans\n") + +for batch_size in [1, 4]: + for input_len in [10, 20, 30, 40, 50]: + create_test(batch_size, input_len) + +if __name__ == '__main__': + unittest.main() diff --git a/turbo_transformers/python/turbo_transformers/layers/__init__.py b/turbo_transformers/python/turbo_transformers/layers/__init__.py index d539aa53..a3807cde 100644 --- a/turbo_transformers/python/turbo_transformers/layers/__init__.py +++ b/turbo_transformers/python/turbo_transformers/layers/__init__.py @@ -19,6 +19,7 @@ from .modeling_decoder import MultiHeadedAttention, PositionwiseFeedForward, TransformerDecoderLayer, TransformerDecoder from .modeling_roberta import RobertaModel from .modeling_gpt2 import GPT2Model +from .modeling_distillbert import DistillBertAttention, DistrillFFN, DistrillTransformerBlock, DistrillTransformer, DistilBertModel from .return_type import ReturnType @@ -33,5 +34,7 @@ 'AlbertAttention', 'AlbertTransformer', 'AlbertModel', 'PositionwiseFeedForward', 'TransformerDecoderLayer', 'TransformerDecoder', 'RobertaModel', 'QBertIntermediate', 'QBertOutput', 'QBertLayer', - 'QBertEncoder', 'QBertModel', 'GPT2Model' + 'QBertEncoder', 'QBertModel', 'GPT2Model', 'DistillBertAttention', + 'DistrillFFN', 'DistrillTransformerBlock', 'DistrillTransformer', + 'DistilBertModel' ] diff --git a/turbo_transformers/python/turbo_transformers/layers/modeling_distillbert.py b/turbo_transformers/python/turbo_transformers/layers/modeling_distillbert.py new file mode 100644 index 00000000..2a9fa8c9 --- /dev/null +++ b/turbo_transformers/python/turbo_transformers/layers/modeling_distillbert.py @@ -0,0 +1,314 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. + +try: + # `turbo_transformers_cxxd` is the name on debug mode + import turbo_transformers.turbo_transformers_cxxd as cxx +except ImportError: + import turbo_transformers.turbo_transformers_cxx as cxx +from typing import Union, Optional, Sequence +import torch +from .return_type import convert_returns_as_type, ReturnType +from .utils import try_convert, convert2tt_tensor, to_param_dict_convert_tt, to_param_dict, create_empty_if_none, AnyTensor + +from transformers.modeling_distilbert import DistilBertConfig +from transformers.modeling_distilbert import MultiHeadSelfAttention as TorchDistilMultiHeadSelfAttention +from transformers.modeling_distilbert import FFN as TorchDistilFFN +from transformers.modeling_distilbert import TransformerBlock as TorchDistilTransformerBlock +from transformers.modeling_distilbert import Transformer as TorchDistilTransformer +from transformers.modeling_distilbert import Embeddings as TorchDistrilEmbeddings +from transformers.modeling_distilbert import DistilBertModel as TorchDistilBertModel + +from torch import nn +import numpy as np +__all__ = [ + 'DistillBertAttention', 'DistrillFFN', 'DistrillTransformerBlock', + 'DistrillTransformer', 'DistilBertModel' +] + + +class DistillBertAttention(cxx.BertAttention): + def __call__(self, + input_tensor: AnyTensor, + attention_mask: Optional[AnyTensor] = None, + head_mask: Optional[AnyTensor] = None, + output_attentions: Optional[bool] = False, + return_type: Optional[ReturnType] = None, + is_trans_weight: Optional[cxx.Tensor] = False): + assert (head_mask is None) + # attention mask is different from BERT + if attention_mask is not None: + attention_mask = attention_mask[:, None, None, :] + attention_mask = ( + 1.0 - attention_mask) * -10000.0 #-float("inf") will cause NAN + + input_tensor = try_convert(input_tensor) + attention_mask = try_convert(create_empty_if_none(attention_mask)) + context_layer = cxx.Tensor.create_empty() + attn_probs = cxx.Tensor.create_empty() + super(DistillBertAttention, + self).__call__(input_tensor, attention_mask, context_layer, + attn_probs, is_trans_weight) + outputs = (convert_returns_as_type(context_layer, return_type), + convert_returns_as_type(attn_probs, ReturnType.TORCH) + ) if output_attentions else (convert_returns_as_type( + context_layer, return_type), ) + return outputs + + @staticmethod + def from_torch(attention: TorchDistilMultiHeadSelfAttention, + layernorm: nn.LayerNorm): + params = {k: v for k, v in attention.named_parameters()} + layernorm_params = {k: v for k, v in layernorm.named_parameters()} + + with torch.no_grad(): + # merge self.query.weight, self.query.weight and self.query.weight together as qkv.weight + qkv_weight = torch.clone( + torch.t( + torch.cat((params['q_lin.weight'], params['k_lin.weight'], + params['v_lin.weight']), + 0).contiguous()).contiguous()) + qkv_bias = torch.cat((params['q_lin.bias'], params['k_lin.bias'], + params['v_lin.bias']), 0).contiguous() + + output_weight = torch.clone( + torch.t(params['out_lin.weight']).contiguous()) + att = DistillBertAttention( + convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias), + convert2tt_tensor(output_weight), + convert2tt_tensor(params['out_lin.bias']), + convert2tt_tensor(layernorm_params['weight']), + convert2tt_tensor(layernorm_params['bias']), attention.n_heads) + + return att + + +class DistrillFFN(cxx.DistrillFFN): + def __call__( + self, + input_tensor: AnyTensor, + return_type: Optional[ReturnType] = None, + is_trans_weight: Optional[bool] = True, #Intel 61xx True is faster + output: Optional[cxx.Tensor] = None): + input_tensor = try_convert(input_tensor) + output = create_empty_if_none(output) + super(DistrillFFN, self).__call__(input_tensor, output, + is_trans_weight) + return convert_returns_as_type(output, return_type) + + @staticmethod + def from_torch(ffn: TorchDistilFFN, + layernorm: nn.LayerNorm, + is_trans_weight: Optional[bool] = True): + ffn_params = {k: v for k, v in ffn.named_parameters()} + layernorm_params = {k: v for k, v in layernorm.named_parameters()} + + # Note that torch's weights of linear layer is transposed + if is_trans_weight: + w_1 = convert2tt_tensor(ffn_params['lin1.weight']) + w_2 = convert2tt_tensor(ffn_params['lin2.weight']) + else: + w_1 = convert2tt_tensor( + torch.clone(torch.t(ffn_params['lin1.weight']).contiguous())) + w_2 = convert2tt_tensor( + torch.clone(torch.t(ffn_params['lin2.weight']).contiguous())) + + with torch.no_grad(): + ffn = DistrillFFN(w_1, convert2tt_tensor(ffn_params['lin1.bias']), + w_2, convert2tt_tensor(ffn_params['lin2.bias']), + convert2tt_tensor(layernorm_params['weight']), + convert2tt_tensor(layernorm_params['bias'])) + return ffn + + +class DistrillTransformerBlock: + def __init__(self, attn: DistillBertAttention, ffn: DistrillFFN): + self.attention = attn + self.ffn = ffn + + def __call__(self, + hidden_states: AnyTensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions=False, + return_type: Optional[ReturnType] = None): + hidden_states = try_convert(hidden_states) + + sa_output = self.attention(hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + return_type=ReturnType.turbo_transformers) + if output_attentions: + sa_output, sa_weights = sa_output + else: + sa_output = sa_output[0] + ffn_output = self.ffn(sa_output) + output = (ffn_output, ) + if output_attentions: + output = (sa_weights, ) + output + return output + + @staticmethod + def from_torch(layer: TorchDistilTransformerBlock): + return DistrillTransformerBlock( + DistillBertAttention.from_torch(layer.attention, + layer.sa_layer_norm), + DistrillFFN.from_torch(layer.ffn, layer.output_layer_norm)) + + +class DistrillTransformer: + def __init__(self, blocks: Sequence[DistrillTransformerBlock]): + self.blocks = blocks + + def __call__(self, + hidden_states: AnyTensor, + attention_mask: Optional[AnyTensor] = None, + head_mask: Optional[AnyTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_type: Optional[ReturnType] = ReturnType.TORCH): + all_hidden_states = () + all_attentions = () + hidden_states = try_convert(hidden_states) + for l in self.blocks: + layer_outputs = l(hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + return_type=ReturnType.turbo_transformers) + if output_hidden_states: + all_hidden_states = all_hidden_states + ( + convert_returns_as_type(hidden_states, ReturnType.TORCH), ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + # outputs = (convert_returns_as_type(hidden_states, return_type), ) + outputs = (hidden_states, ) + # Add last layer + if output_hidden_states: + # TODO(jiaruifang)two return value use the same memory space, that is not supported in dlpack. + # So we do not append the last hidden_state at the buttom of all_hidden_states, + # User should use outputs[0] if necessary + # all_hidden_states = all_hidden_states + (convert_returns_as_type(hidden_states, ReturnType.TORCH),) + pass + + if output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if output_attentions: + outputs = outputs + (all_attentions, ) + + return outputs + + @staticmethod + def from_torch(transform: TorchDistilTransformer): + blocks = [ + DistrillTransformerBlock.from_torch(l) for l in transform.layer + ] + return DistrillTransformer(blocks) + + +class DistilBertModel: + def __init__(self, + embeddings_onnxmodel_variant, + transformer: DistrillTransformer, + backend="turbo"): + if backend == "turbo": + self.embeddings = embeddings_onnxmodel_variant + self.transformer = transformer + self.backend = "turbo" + elif backend == "onnxrt": + self.onnxmodel = embeddings_onnxmodel_variant + self.backend = "onnxrt" + + def __call__(self, + input_ids: AnyTensor, + attention_masks: Optional[AnyTensor] = None, + token_type_ids: Optional[AnyTensor] = None, + position_ids: Optional[AnyTensor] = None, + head_mask: Optional[AnyTensor] = None, + inputs_embeds: Optional[AnyTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_type: Optional[ReturnType] = None): + if self.backend == "onnxrt": + if attention_masks is None: + attention_masks = np.ones(input_ids.size(), dtype=np.int64) + else: + attention_masks = attention_masks.cpu().numpy() + data = [input_ids.cpu().numpy(), attention_masks] + outputs = self.onnxmodel.run(inputs=data) + for idx, item in enumerate(outputs): + outputs[idx] = torch.tensor(item, device=input_ids.device) + return outputs + elif self.backend == "turbo": + # torch part + inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim) + inputs_embeds = try_convert(inputs_embeds) + + # turbo part + transformer_outputs = self.transformer( + hidden_states=inputs_embeds, + attention_mask=attention_masks, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_type=return_type) + return transformer_outputs + + @staticmethod + def from_torch(model: TorchDistilBertModel, backend="turbo"): + """ + :param model: a torch distrilBert Model + backend: turbo or onnxrt + move model to gpu before call this function. + """ + if backend == "turbo": + transformer = DistrillTransformer.from_torch(model.transformer) + return DistilBertModel(model.embeddings, transformer, "turbo") + elif backend == "onnxrt": + import onnx + import onnxruntime.backend + device = model.device + if 'cuda' in device.type and torch.cuda.is_available(): + use_gpu = True + else: + use_gpu = False + inputs = { + 'input_ids': + torch.randint(32, [2, 32], dtype=torch.long).to( + device), # list of numerical ids for the tokenised text + 'attention_mask': + torch.ones([2, 32], + dtype=torch.long).to(device), # dummy list of ones + } + onnx_model_path = "/tmp/temp_turbo_onnx.model" + with open(onnx_model_path, 'wb') as outf: + torch.onnx.export( + model=model, + args=(inputs['input_ids'], inputs['attention_mask'] + ), # model input (or a tuple for multiple inputs) + f=outf, + input_names=['input_ids', 'attention_mask'], + output_names=['output'], + dynamic_axes={ + 'input_ids': [0, 1], + 'attention_mask': [0, 1] + }) + onnx_model = onnx.load_model(f=onnx_model_path) + onnx_model = onnxruntime.backend.prepare( + model=onnx_model, + device='GPU' if use_gpu else "CPU", + graph_optimization_level=onnxruntime.GraphOptimizationLevel. + ORT_ENABLE_ALL) + return DistilBertModel(onnx_model, None, "onnxrt")