
❓ [Question] Why BERT Base is slower w/ Torch-TensorRT than native PyTorch? #830

Closed
void-main opened this issue Jan 26, 2022 · 9 comments

@void-main

❓ Question

I'm trying to optimize Hugging Face's BERT Base uncased model using Torch-TensorRT. The code works after disabling full compilation (require_full_compilation=False), and the average latency is ~10ms on a T4. However, that is slower than the native PyTorch implementation (~6ms on the same T4), and running the same model with trtexec takes only ~4ms. So for BERT Base, Torch-TensorRT is 2.5x slower than plain TensorRT. Is this expected?
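For reference, a trtexec baseline like the one mentioned is typically produced along these lines (a hedged sketch, not the exact commands from this issue; the ONNX file name and input names are assumptions, and traced_model/dummy_input refer to the script below):

# Export the traced model to ONNX, then build and time a TensorRT engine
# with trtexec. Paths and tensor names here are illustrative assumptions.
torch.onnx.export(traced_model, tuple(dummy_input), "bert-base.onnx",
                  input_names=["input_ids", "token_type_ids"],
                  opset_version=13)
# Then, from the shell:
#   trtexec --onnx=bert-base.onnx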

Here's the full code:

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import time

enc = BertTokenizer.from_pretrained("./bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens]).to(torch.int32).cuda()
segments_tensors = torch.tensor([segments_ids]).to(torch.int32).cuda()

dummy_input = [tokens_tensor, segments_tensors]
dummy_input_shapes = [list(v.size()) for v in dummy_input]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(vocab_size=32000, hidden_size=768,
    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("./bert-base-uncased", torchscript=True)

model = model.eval().cuda()

# Creating the trace
traced_model = torch.jit.trace(model, dummy_input)

import torch_tensorrt
compile_settings = {
    "require_full_compilation": False,
    "truncate_long_and_double": True,
    "torch_executed_ops": ["aten::Int"]
}
optimized_model = torch_tensorrt.compile(traced_model, inputs=dummy_input, **compile_settings)

def benchmark(model, input):
    # Warming up
    for _ in range(10):
        model(*input)
    # CUDA launches are asynchronous; wait for pending work before timing
    torch.cuda.synchronize()

    inference_count = 1000
    # inference test
    start = time.time()
    for _ in range(inference_count):
        model(*input)
    # Wait for queued kernels to finish before reading the clock
    torch.cuda.synchronize()
    end = time.time()
    print(f"use {(end-start)/inference_count*1000} ms each inference")
    print(f"{inference_count/(end-start)} step/s")

print("before compile")
benchmark(traced_model, dummy_input)

print("after compile")
benchmark(optimized_model, dummy_input)

So my question is: why is it slower than native PyTorch, and how do I tune it?

What you have already tried

I've checked the log from Torch-TensorRT; it looks like the model is partitioned into 3 segments, separated by the aten::Int op, and that op appears to be hard to write a converter for.
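For anyone trying to reproduce the partition report, it can be printed by raising the log level before compiling (a minimal sketch; set_reportable_log_level and Level.Debug are from the torch_tensorrt.logging module in the 1.x API, as I understand it):

import torch_tensorrt

# Emit debug-level logs during compilation, including which graph segments
# are converted to TensorRT and which fall back to Torch.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)
optimized_model = torch_tensorrt.compile(
    traced_model,
    inputs=dummy_input,
    require_full_compilation=False,
    truncate_long_and_double=True,
    torch_executed_ops=["aten::Int"],
)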

Next, I profiled the inference process with Nsight Systems; here's the screenshot:
[screenshot: Nsight Systems timeline of one inference, captured 2022-01-26]

Seeing 3 separate segments is expected. However, two things caught my attention:

  1. Why is segment 0 slower than pure TensorRT? Is it due to an overly complicated conversion?
  2. Why does the cudaMemcpyAsync take so long? Shouldn't it only copy back the last_hidden_state tensor?

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 1.10
  • CPU Architecture:
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): python setup.py develop
  • Are you using local sources or building from archives: local sources
  • Python version: 3.6.9
  • CUDA version: 10.2
  • GPU models and configuration: T4
  • Any other relevant information:

Additional context

void-main added the question (Further information is requested) label on Jan 26, 2022
@peri044
Collaborator

peri044 commented Jan 31, 2022

End-to-end compilation:
aten::Int is a bit tricky to implement and we are looking into it. This op just converts an integer scalar into an integer tensor, which is redundant for static-sequence-length inputs. A workaround is to avoid having aten::Int in the BERT TorchScript model in the first place. You can do that by regenerating the BERT model after making a one-line change here

seq_length = input_shape[1]

to

seq_length = int(input_shape[1])

Once you make this change, please recompile and reinstall transformers, then regenerate a fresh copy of the BERT model using the exporter: https://huggingface.co/docs/transformers/serialization#using-torchscript-in-python

With this change, you should be able to convert the entire BERT model into a TensorRT engine (you can verify this by setting require_full_compilation=True). Performance on T4 should improve and beat PyTorch as well.
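Concretely, the verification step could look like this (a sketch reusing model and dummy_input from the script above):

# After patching modeling_bert.py, re-trace so the TorchScript graph no
# longer contains aten::Int, then require full compilation so any remaining
# fallback surfaces as an error instead of a silent partition.
traced_model = torch.jit.trace(model, dummy_input)
trt_model = torch_tensorrt.compile(
    traced_model,
    inputs=dummy_input,
    require_full_compilation=True,
    truncate_long_and_double=True,
)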

@narendasan
Collaborator

Just to note, you don't necessarily need to reinstall transformers; you can patch this directly in your installed copy, since it's just a change in the Python code.
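One way to find the file to patch (assuming the transformers 4.x module layout):

# Print the path of the installed BERT implementation so the seq_length
# line can be edited in place.
import transformers.models.bert.modeling_bert as modeling_bert
print(modeling_bert.__file__)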

@void-main
Author

Hi @peri044 @narendasan , thank you so much for the suggestions and sorry for my late reply.

I edited modeling_bert.py and set require_full_compilation=True to make sure the whole graph is converted to TensorRT. However, the latency is still much worse than the native implementation.

Could you please share your running code / configuration so that I could try to figure out what's wrong with my code?

Here's my code:

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import time

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens]).to(torch.int32).cuda()
segments_tensors = torch.tensor([segments_ids]).to(torch.int32).cuda()

dummy_input = [tokens_tensor, segments_tensors]
dummy_input_shapes = [list(v.size()) for v in dummy_input]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(vocab_size=32000, hidden_size=768,
    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

model = model.eval().cuda()

# Creating the trace
traced_model = torch.jit.trace(model, dummy_input)

import torch_tensorrt
compile_settings = {
    "require_full_compilation": True,
    # "require_full_compilation": False,
    "truncate_long_and_double": True,
    # "torch_executed_ops": ["aten::Int"]
}
optimized_model = torch_tensorrt.compile(traced_model, inputs=dummy_input, **compile_settings)

def benchmark(model, input):
    # Warming up
    for _ in range(10):
        model(*input)
    # CUDA launches are asynchronous; wait for pending work before timing
    torch.cuda.synchronize()

    inference_count = 1000
    # inference test
    start = time.time()
    for _ in range(inference_count):
        model(*input)
    # Wait for queued kernels to finish before reading the clock
    torch.cuda.synchronize()
    end = time.time()
    print(f"use {(end-start)/inference_count*1000} ms each inference")
    print(f"{inference_count/(end-start)} step/s")

print("before compile")
benchmark(traced_model, dummy_input)

print("after compile")
benchmark(optimized_model, dummy_input)

@void-main
Author

Also, I get a lot of warnings during inference, but they seem to be ignorable:

WARNING: [Torch-TensorRT TorchScript Conversion Context] - Tensor DataType is determined at build time for tensors not marked as input or output.
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
(the aten::size and Float64-to-Float32 truncation warnings then repeat many times, roughly once per encoder layer)

@void-main
Author

I'm using the latest commit:

>>> import torch_tensorrt
>>> torch_tensorrt.__version__
'1.1.0a0+4fd886d0'

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

@andi4191
Contributor

andi4191 commented Jun 1, 2022

Hi @void-main,

I tried your script (#830 (comment)) on the master branch and observed the following logs:

before compile
use 7.4726243019104 ms each inference
133.82179534227978 step/s
after compile
use 2.0044188499450684 ms each inference
498.8977229122572 step/s

I used a Titan V card on my host machine. Are you saying that the compiled model has higher latency than the uncompiled one on your end?

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

@Liujingxiu23

Liujingxiu23 commented Nov 9, 2023

@andi4191 I am also hitting the "slower" problem.
torch: 2.0.1+cu117
tensorrt: 8.6.1
torch_tensorrt: 1.4.0

import os
import time
import warnings
import torch
from torch import nn
import torch_tensorrt

transformer = nn.TransformerEncoderLayer(512, 8).eval().cuda()
input_ = torch.rand((10, 1, 512), dtype=torch.float).to("cuda")
inputs = [input_]

# Warm up the eager model
for i in range(10):
    logits = transformer(input_)

start_time = time.time()
for i in range(10000):
    logits = transformer(input_)
end_time = time.time()
print("using ", (end_time-start_time)*1000, "ms")

compilation_kwargs = {
    "enabled_precisions": {torch.int8},
    "debug": True,
    "workspace_size": 20 << 30,
    "min_block_size": 7,
    "torch_executed_ops": {},
}

optimized_model = torch.compile(
    transformer,
    backend="torch_tensorrt",
    options=compilation_kwargs,
)
optimized_model(*inputs)

# Warm up the compiled model
for i in range(10):
    logits = optimized_model(*inputs)

start_time = time.time()
for i in range(10000):
    logits = optimized_model(*inputs)
end_time = time.time()
print("tensorrt using ", (end_time-start_time)*1000, "ms")

input_ = torch.rand((10, 10, 512), dtype=torch.float).to("cuda")
inputs = [input_]
outputs = optimized_model(*inputs)
print("outputs:", outputs.shape)
