Develop (#195)
Add distilBERT
feifeibear authored Nov 24, 2020
1 parent 0641b54 commit 68b8f72
Showing 21 changed files with 1,031 additions and 17 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -21,7 +21,8 @@ set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_FLAGS "-Wall")
 set(CMAKE_C_FLAGS "-Wall")
 
-set(TURBO_TRANSFORMERS_VERSION 0.5.0)
+
+set(TURBO_TRANSFORMERS_VERSION 0.5.1)
 
 
 option(WITH_PROFILER "Compile with profiler" OFF)
3 changes: 0 additions & 3 deletions benchmark/benchmark.py
@@ -30,9 +30,6 @@
 --enable_mem_opt Use model aware memory optimization for BERT.
 """
 
-import json
-import os
-
import docopt
from turbo_benchmark_helper import benchmark_turbo_transformers
from torch_benchmark_helper import benchmark_torch
4 changes: 3 additions & 1 deletion benchmark/onnx_benchmark_helper.py
@@ -26,7 +26,6 @@ def generate_onnx_model(model_name: str,
                         use_dynamic_axes: bool = False):
     import transformers
     import torch
-    import os
 
     test_device = torch.device(
         'cuda:0') if backend == "GPU" and use_gpu else torch.device('cpu:0')
@@ -45,6 +44,9 @@ def generate_onnx_model(model_name: str,
     elif model_name == "roberta":
         cfg = transformers.RobertaConfig()
         model = transformers.RobertaModel(cfg)
+    elif model_name == "distilbert":
+        cfg = transformers.DistilBertConfig()
+        model = transformers.DistilBertModel(cfg)
     else:
         raise (f"benchmark does not support {model_name}")
 
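The export call itself sits below this hunk, outside the visible diff. As a hedged sketch only, the new distilbert branch feeds its model into something like the following; every export argument here is illustrative rather than copied from the helper:

import torch
import transformers

cfg = transformers.DistilBertConfig()
model = transformers.DistilBertModel(cfg)
model.eval()

# dummy input ids, mirroring the random-input pattern used elsewhere in this commit
dummy_input = torch.randint(low=0, high=cfg.vocab_size - 1, size=(1, 128),
                            dtype=torch.long)
# illustrative export call; generate_onnx_model's actual arguments are not shown here
torch.onnx.export(model, (dummy_input, ), "distilbert.onnx",
                  input_names=["input_ids"], output_names=["output"])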
7 changes: 3 additions & 4 deletions benchmark/torch_benchmark_helper.py
@@ -22,10 +22,6 @@ def benchmark_torch(model_name: str, seq_len: int, batch_size: int, n: int,
     import benchmark_helper
 
     test_device = torch.device('cuda:0') if use_gpu else torch.device('cpu:0')
-    if use_gpu:
-        print("using GPU")
-    else:
-        print("using CPU")
     torch.set_grad_enabled(False)
     torch.set_num_threads(num_threads)
 
@@ -39,6 +35,9 @@
     elif model_name == "roberta":
         cfg = transformers.RobertaConfig()
         model = transformers.RobertaModel(cfg)
+    elif model_name == "distilbert":
+        cfg = transformers.DistilBertConfig()
+        model = transformers.DistilBertModel(cfg)
     else:
         raise (f"benchmark does not support {model_name}")
     model.eval()
10 changes: 6 additions & 4 deletions benchmark/turbo_benchmark_helper.py
@@ -24,10 +24,6 @@ def benchmark_turbo_transformers(model_name: str, seq_len: int,
     import turbo_transformers
     import benchmark_helper
     test_device = torch.device('cuda:0') if use_gpu else torch.device('cpu:0')
-    if use_gpu:
-        print("using GPU")
-    else:
-        print("using CPU")
     cfg = None
     torch.set_grad_enabled(False)
     if model_name == "bert":
@@ -48,6 +44,12 @@
         model.to(test_device)
         model.eval()
         model = turbo_transformers.RobertaModel.from_torch(model)
+    elif model_name == "distilbert":
+        cfg = transformers.DistilBertConfig()
+        model = transformers.DistilBertModel(cfg)
+        model.to(test_device)
+        model.eval()
+        model = turbo_transformers.DistilBertModel.from_torch(model)
     else:
         raise (f"benchmark does not support {model_name}")
 
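With this branch in place, distilbert goes through the same convert-then-benchmark flow as bert and roberta. A minimal sketch of that flow outside the harness; only the construction and from_torch conversion appear in the hunk above, so the forward-call signature of the converted model is an assumption:

import torch
import transformers
import turbo_transformers

torch.set_grad_enabled(False)
torch.set_num_threads(4)

cfg = transformers.DistilBertConfig()
model = transformers.DistilBertModel(cfg)
model.eval()
turbo_model = turbo_transformers.DistilBertModel.from_torch(model)

# benchmark-style random input ids
input_ids = torch.randint(low=0, high=cfg.vocab_size - 1, size=(1, 128),
                          dtype=torch.long)
output = turbo_model(input_ids)  # assumed call signature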
Empty file added distrill/bert_model.txt
Empty file added distrill/distill_bert.txt
45 changes: 45 additions & 0 deletions distrill/distrill_bert.py
@@ -0,0 +1,45 @@
# Copyright (C) 2020 THL A29 Limited, a Tencent company.
# All rights reserved.
# Licensed under the BSD 3-Clause License (the "License"); you may
# not use this file except in compliance with the License. You may
# obtain a copy of the License at
# https://opensource.org/licenses/BSD-3-Clause
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# See the AUTHORS file for names of contributors.

from transformers import DistilBertTokenizer, DistilBertModel
from transformers import BertTokenizer, BertModel
import torch

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# inputs = torch.randint(low=0,
# high=cfg.vocab_size - 1,
# size=(1, 10),
# dtype=torch.long,
# device=torch.device("cpu:0"))

## distrillation model
model = DistilBertModel.from_pretrained("distilbert-base-uncased",
return_dict=True)

## bert model
bert_model = BertModel.from_pretrained("bert-base-uncased", return_dict=True)

cfg = model.config
print(cfg)
print(inputs)
outputs = model(**inputs)
bert_outputs = bert_model(**inputs)

print(model)
print(bert_model)

# print(bert_outputs - outputs)
#
# last_hidden_states = outputs.last_hidden_state
# print(last_hidden_states)
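The commented-out tail gestures at an output comparison the script never performs. A sketch of what it would look like; with return_dict=True both models expose last_hidden_state, and both are (1, seq_len, 768) here since distilbert-base keeps BERT's hidden size:

distil_hidden = outputs.last_hidden_state      # (1, seq_len, 768)
bert_hidden = bert_outputs.last_hidden_state   # (1, seq_len, 768)
print(distil_hidden.shape, bert_hidden.shape)
# the element-wise gap is large: DistilBERT is a distilled 6-layer model,
# not a numerical drop-in for the 12-layer bert-base
print(torch.max(torch.abs(distil_hidden - bert_hidden)))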
2 changes: 1 addition & 1 deletion requirements.txt
@@ -14,4 +14,4 @@
 contexttimer
 onnx
 future
-transformers==3.0.2
+transformers==3.4.0
2 changes: 1 addition & 1 deletion tools/docker/Dockerfile_release.cpu
@@ -14,5 +14,5 @@ RUN /opt/conda/bin/conda install pytorch==1.5.0 cpuonly -c pytorch && \
     /opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \
     /opt/conda/bin/conda clean -afy
 
-RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt onnxruntime-tools
+RUN pip --no-cache-dir install contexttimer future transformers==3.4.0 docopt onnxruntime-tools
 WORKDIR /workspace
4 changes: 3 additions & 1 deletion tools/docker/Dockerfile_release.gpu
@@ -13,7 +13,9 @@ RUN curl -LO https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-
     conda install pytorch=PYTORCH_VERSION cudatoolkit=CUDA_VERSION cudnn --freeze-installed -c pytorch && \
     conda clean -yfa
 
-RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt OpenNMT-py==1.1.0 onnxruntime-gpu==1.3.0
+
+RUN pip --no-cache-dir install contexttimer future transformers==3.4.0 docopt OpenNMT-py==1.1.0 onnxruntime-gpu==1.3.0
+
 
 COPY --from=DEV_IMAGE /opt/miniconda3/lib/python3.7/site-packages/turbo_transformers /opt/miniconda3/lib/python3.7/site-packages/turbo_transformers
48 changes: 48 additions & 0 deletions turbo_transformers/layers/positionwise_ffn.cpp
@@ -76,5 +76,53 @@ void PositionwiseFeedForward::operator()(const core::Tensor& input_tensor,

void PositionwiseFeedForward::EnforceShapeAndType() const {}

void DistrillFFN::operator()(const core::Tensor& input_tensor,
core::Tensor* output_tensor,
bool is_trans_weight) const {
auto d_ff =
is_trans_weight ? dense_weight_1_.shape(0) : dense_weight_1_.shape(1);

auto model_dim_weight =
is_trans_weight ? dense_weight_1_.shape(1) : dense_weight_1_.shape(0);
auto model_dim = input_tensor.shape(2);

TT_ENFORCE_EQ(
model_dim_weight, model_dim,
"dense weight and input tensor should have the same model_dim.");

auto devType = input_tensor.device_type();
auto devId = input_tensor.device_id();

// input tensor size (batch_size, input_len, model_dim)
auto batch_size = input_tensor.shape(0);
auto input_len = input_tensor.shape(1);
// allocate memory for temp data
core::Tensor input_tensor_copy(nullptr);
input_tensor_copy.Reshape<float>({batch_size, input_len, model_dim}, devType,
devId);
core::Tensor temp_tensor(nullptr);
temp_tensor.Reshape<float>({batch_size * input_len, d_ff}, devType, devId);

// start computation
core::Copy<float>(input_tensor, input_tensor_copy, "FFN/AddInputBias");

output_tensor->Reshape<float>({batch_size, input_len, model_dim}, devType,
devId, "FFN/Reshape");
kernels::MatMul(input_tensor_copy, false, dense_weight_1_, is_trans_weight,
1.0, // input (b*seq, model) X dense_weight_1_ (model_dim,
// d_ff) -> temp_tensor (B*seq, d_ff)
&temp_tensor, 0.0, "FFN/gemm0");
kernels::AddBiasAct<float, types::ActivationType::Gelu>(
dense_bias_1_, &temp_tensor, "FFN/AddBiasAct");
kernels::MatMul(temp_tensor, false, dense_weight_2_, is_trans_weight, 1.0,
&input_tensor_copy, 0.0, "FFN/gemm1");
kernels::AddInputBias(input_tensor, input_tensor_copy, dense_bias_2_,
output_tensor, "FFN/AddInputBias");
kernels::LayerNorm<float>(layer_norm_weight_, layer_norm_bias_, output_tensor,
1e-12, "FFN/LayerNorm");
}

void DistrillFFN::EnforceShapeAndType() const {}

} // namespace layers
} // namespace turbo_transformers
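Read as a whole, the kernel sequence above computes DistilBERT's FFN block with its residual connection and post-LayerNorm: output = LayerNorm(x + gelu(x·W1 + b1)·W2 + b2) with eps = 1e-12. A minimal PyTorch reference of the same math, assuming the is_trans_weight=False layout (w1 shaped (model_dim, d_ff), w2 shaped (d_ff, model_dim)):

import torch
import torch.nn.functional as F

def distrill_ffn_reference(x, w1, b1, w2, b2, ln_w, ln_b):
    # FFN/gemm0 + FFN/AddBiasAct: expand to d_ff, add bias, apply gelu
    h = F.gelu(x @ w1 + b1)
    # FFN/gemm1 + FFN/AddInputBias: project back and add the residual input
    y = x + h @ w2 + b2
    # FFN/LayerNorm with the same epsilon as the kernel call above
    return F.layer_norm(y, (x.shape[-1],), weight=ln_w, bias=ln_b, eps=1e-12)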
30 changes: 30 additions & 0 deletions turbo_transformers/layers/positionwise_ffn.h
@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <utility>

#include "turbo_transformers/core/tensor.h"

namespace turbo_transformers {
@@ -51,5 +52,34 @@ class PositionwiseFeedForward {
core::Tensor layer_norm_bias_;
};

class DistrillFFN {
public:
DistrillFFN(core::Tensor dense_weight_1, core::Tensor dense_bias_1,
core::Tensor dense_weight_2, core::Tensor dense_bias_2,
core::Tensor layer_norm_weight, core::Tensor layer_norm_bias)
: dense_weight_1_(std::move(dense_weight_1)),
dense_bias_1_(std::move(dense_bias_1)),
dense_weight_2_(std::move(dense_weight_2)),
dense_bias_2_(std::move(dense_bias_2)),
layer_norm_weight_(std::move(layer_norm_weight)),
layer_norm_bias_(std::move(layer_norm_bias)) {
EnforceShapeAndType();
}
void EnforceShapeAndType() const;

// according to profiling results on Intel 61xx, is_trans_weight = true is
// faster
void operator()(const core::Tensor &input_tensor, core::Tensor *output,
bool is_trans_weight = true) const;

private:
core::Tensor dense_weight_1_;
core::Tensor dense_bias_1_;
core::Tensor dense_weight_2_;
core::Tensor dense_bias_2_;
core::Tensor layer_norm_weight_;
core::Tensor layer_norm_bias_;
};

} // namespace layers
} // namespace turbo_transformers
12 changes: 12 additions & 0 deletions turbo_transformers/python/pybind.cpp
@@ -221,6 +221,12 @@ PYBIND11_MODULE(turbo_transformers_cxx, m) {
}))
.def("__call__", &layers::PositionwiseFeedForward::operator());

py::class_<layers::DistrillFFN>(m, "DistrillFFN")
.def(py::init([](core::Tensor &dense_weight_1, core::Tensor &dense_bias_1,
core::Tensor &dense_weight_2, core::Tensor &dense_bias_2,
core::Tensor &layer_norm_weight,
core::Tensor &layer_norm_bias) -> layers::DistrillFFN * {
return new layers::DistrillFFN(
std::move(dense_weight_1), std::move(dense_bias_1),
std::move(dense_weight_2), std::move(dense_bias_2),
std::move(layer_norm_weight), std::move(layer_norm_bias));
}))
.def("__call__", &layers::DistrillFFN::operator());

py::class_<layers::FusedAddBiasGELU>(m, "FusedAddBiasGELU")
.def(py::init([](core::Tensor &dense_bias) -> layers::FusedAddBiasGELU * {
return new layers::FusedAddBiasGELU(std::move(dense_bias));
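A hedged sketch of driving the new DistrillFFN binding from Python. The DLPack round-trip mirrors the conversion pattern used elsewhere in this repo's Python layer, but that helper, cxx.Tensor.create_empty, and the positional call are assumptions based on this diff rather than documented API:

import torch
from torch.utils.dlpack import to_dlpack
import turbo_transformers.turbo_transformers_cxx as cxx

tt = lambda t: cxx.Tensor.from_dlpack(to_dlpack(t.contiguous()))  # assumed helper pattern

model_dim, d_ff = 768, 3072
x = torch.rand(2, 10, model_dim)
# is_trans_weight=True layout, the faster option per the header comment
w1, b1 = torch.rand(d_ff, model_dim), torch.rand(d_ff)
w2, b2 = torch.rand(model_dim, d_ff), torch.rand(model_dim)
ln_w, ln_b = torch.ones(model_dim), torch.zeros(model_dim)

ffn = cxx.DistrillFFN(tt(w1), tt(b1), tt(w2), tt(b2), tt(ln_w), tt(ln_b))
out = cxx.Tensor.create_empty()  # assumption: empty output-tensor factory
ffn(tt(x), out, True)            # is_trans_weight=True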
117 changes: 117 additions & 0 deletions turbo_transformers/python/tests/distill_bert_attention_test.py
@@ -0,0 +1,117 @@
# Copyright (C) 2020 THL A29 Limited, a Tencent company.
# All rights reserved.
# Licensed under the BSD 3-Clause License (the "License"); you may
# not use this file except in compliance with the License. You may
# obtain a copy of the License at
# https://opensource.org/licenses/BSD-3-Clause
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# See the AUTHORS file for names of contributors.

import turbo_transformers

import unittest
import sys
import torch
from transformers.modeling_distilbert import DistilBertConfig
from transformers.modeling_distilbert import MultiHeadSelfAttention as DistilAttention
from torch import nn

import os
sys.path.append(os.path.dirname(__file__))
import test_helper

fname = "tt_distrill_attention.txt"


def create_test(batch_size, seq_length):
class TestDistillBertAttention(unittest.TestCase):
def init_data(self, use_cuda):
test_device = torch.device('cuda:0') if use_cuda else \
torch.device('cpu:0')
if not use_cuda:
torch.set_num_threads(4)
turbo_transformers.set_num_threads(4)

torch.set_grad_enabled(False)
self.cfg = DistilBertConfig(attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0)
self.cfg.output_attentions = True
self.torch_attention = DistilAttention(self.cfg)
self.torch_sa_layer_norm = nn.LayerNorm(
normalized_shape=self.cfg.dim, eps=1e-12)
self.torch_attention.eval()
self.torch_sa_layer_norm.eval()
if use_cuda:
self.torch_attention.to(test_device)
self.torch_sa_layer_norm.to(test_device)

# Get FT Attention
self.turbo_attention = turbo_transformers.DistillBertAttention.from_torch(
self.torch_attention, self.torch_sa_layer_norm)

hidden_size = self.cfg.hidden_size
self.input_tensor = torch.rand(size=(batch_size, seq_length,
hidden_size),
dtype=torch.float32,
device=test_device)
# NOTE, the mask of distilled attention is different from huggingface bert attention.
self.attention_mask = torch.ones((batch_size, seq_length),
dtype=torch.float32,
device=test_device)

def check_torch_and_turbo(self, use_cuda, num_iter=1):
self.init_data(use_cuda)
device = "GPU" if use_cuda else "CPU"
torch_model = lambda: self.torch_sa_layer_norm(
self.torch_attention(query=self.input_tensor,
key=self.input_tensor,
value=self.input_tensor,
mask=self.attention_mask,
output_attentions=False)[0] + self.
input_tensor)
torch_attention_result, torch_qps, torch_time_consume = \
test_helper.run_model(torch_model, use_cuda, num_iter, use_profile=False)
print(
f"DistilAttention+LN \"({batch_size},{seq_length:03})\" ",
f"{device} Torch QPS, {torch_qps}, time, {torch_time_consume}")

turbo_model = lambda: self.turbo_attention(
self.input_tensor,
self.attention_mask,
output_attentions=self.cfg.output_attentions)[0]

turbo_attention_result, turbo_qps, turbo_time_consume = \
test_helper.run_model(turbo_model, use_cuda,
num_iter)
print(
f"DistilAttention \"({batch_size},{seq_length:03})\" ",
f" {device} Turbo QPS, {turbo_qps}, time, {turbo_time_consume}"
)

self.assertTrue(
torch.max(
torch.abs(torch_attention_result - turbo_attention_result))
< (1e-3 if use_cuda else 1e-4))

def test_distillbert_attention(self):
self.check_torch_and_turbo(use_cuda=False, num_iter=1)
if torch.cuda.is_available() and \
turbo_transformers.config.is_compiled_with_cuda():
self.check_torch_and_turbo(use_cuda=True, num_iter=1)

globals(
)[f"TestDistillBertAtt{batch_size}_{seq_length:3}"] = TestDistillBertAttention


with open(fname, "w") as fh:
fh.write(", torch, turbo_transformers\n")
for batch_size in [1, 2]:
for seq_length in [10, 20, 128]:
create_test(batch_size, seq_length)

if __name__ == '__main__':
unittest.main()
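On the NOTE in init_data above: huggingface's DistilBert MultiHeadSelfAttention takes a (batch, seq_len) mask of ones and zeros and applies the masking to attention scores internally, while BertModel-style attention consumes a pre-extended additive mask. A sketch of the two conventions:

import torch

mask = torch.ones(2, 10)  # DistilBert-style: 1 = attend, 0 = masked
# BERT-style equivalent: broadcastable additive mask, zero where attended,
# a large negative value where masked
extended_mask = (1.0 - mask)[:, None, None, :] * -10000.0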
