chore: update perf tooling to add dynamo options (#2423)
Signed-off-by: Dheeraj Peri <[email protected]>
Co-authored-by: George S <[email protected]>
peri044 and gs-olive authored Nov 1, 2023
1 parent 21cf0d0 commit 6266443
Showing 11 changed files with 444 additions and 561 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docker_builder.yml
@@ -6,6 +6,7 @@ on:
branches:
  - main
  - nightly
+ - release/2.1

# If pushes to main are made in rapid succession,
# cancel existing docker builds and use newer commits
92 changes: 14 additions & 78 deletions tools/perf/README.md
@@ -3,9 +3,10 @@
This is a comprehensive Python benchmark suite for running performance tests with different supported backends. The following backends are supported:

1. Torch
-2. Torch-TensorRT
-3. FX-TRT
-4. TensorRT
+2. Torch-TensorRT [Torchscript]
+3. Torch-TensorRT [Dynamo]
+4. Torch-TensorRT [torch_compile]
+5. TensorRT
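The three Torch-TensorRT paths differ mainly in how the module is captured and compiled. As a rough illustration, here is a minimal sketch assuming the torch_tensorrt 2.x API (the `ir` values and the `torch.compile` backend name follow that API's conventions and are not taken from this commit):

```python
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18().eval().cuda()
x = torch.randn(1, 3, 224, 224, device="cuda")

# Torchscript path: script/trace the module and compile ahead of time
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=[x])

# Dynamo path: capture the graph with Dynamo and lower it to TensorRT ahead of time
trt_dyn = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])

# torch_compile path: JIT workflow, compilation is deferred to the first call
trt_jit = torch.compile(model, backend="torch_tensorrt")
trt_jit(x)
```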


Note: For ONNX models, users can convert the ONNX model to a TensorRT serialized engine and then use this package.
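For example, a serialized engine can be produced from an ONNX file along these lines (a minimal sketch assuming the TensorRT 8.x Python API and a static-shape model; the file names are placeholders):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

# Parse the ONNX graph into a TensorRT network definition
with open("model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)

# Files ending in .plan are treated as serialized TensorRT engines by this tooling
with open("model.plan", "wb") as f:
    f.write(engine_bytes)
```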
@@ -22,9 +23,6 @@ Benchmark scripts depends on following Python packages in addition to requiremen

```
./
-├── config
-│   ├── vgg16_trt.yml
-│   └── vgg16.yml
├── models
├── perf_run.py
├── hub.py
@@ -35,87 +33,20 @@ Benchmark scripts depends on following Python packages in addition to requiremen
```



-* `config` - Directory which contains sample yaml configuration files for VGG network.
* `models` - Model directory
-* `perf_run.py` - Performance benchmarking script which supports torch, torch_tensorrt, fx2trt, tensorrt backends
+* `perf_run.py` - Performance benchmarking script which supports the torch, ts_trt, torch_compile, dynamo, and tensorrt backends
* `hub.py` - Script to download torchscript models for VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT (see the sketch after this list)
* `custom_models.py` - Script which includes custom models other than torchvision and timm (eg: HF BERT)
* `utils.py` - Utility functions script
* `benchmark.sh` - Used for internal performance testing of VGG16, Resnet50, EfficientNet-B0, VIT, HF-BERT.
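For reference, the scripted-model files used by these scripts can be produced along these lines (a sketch of what the download step is assumed to do, not a transcription of `hub.py`):

```python
import torch
import torchvision.models as models

# Assumption: hub.py stores scripted torchvision models under models/
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT).eval()
torch.jit.script(vgg16).save("models/vgg16_scripted.jit.pt")
```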

## Usage

-There are two ways you can run a performance benchmark.
-
-### Using YAML config files
-
-To run the benchmark for a given configuration file:
-
-```python
-python perf_run.py --config=config/vgg16.yml
-```
-
-There are two sample configuration files added.
-
-* vgg16.yml demonstrates a configuration with all the supported backends (Torch, Torch-TensorRT, TensorRT)
-* vgg16_trt.yml demonstrates how to use an external TensorRT serialized engine file directly.
-
-
-### Supported fields
-
-| Name | Supported Values | Description |
-| ----------------- | ------------------------------------ | ------------------------------------------------------------ |
-| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" implies the last four methods in the list at left, and "torchscript" implies the last three (excludes fx path) |
-| input | - | Input binding names. Expected to list shapes of each input bindings |
-| model | - | Configure the model filename and name |
-| model_torch | - | Name of torch model file and name (used for fx2trt) (optional) |
-| filename | - | Model file name to load from disk. |
-| name | - | Model name |
-| runtime | - | Runtime configurations |
-| device | 0 | Target device ID to run inference. Range depends on available GPUs |
-| precision | fp32, fp16 or half, int8 | Target precision to run inference. int8 cannot be used with 'all' backend |
-| calibration_cache | - | Calibration cache file expected for torch_tensorrt runtime in int8 precision |
-
-Additional sample use case:
-
-```
-backend:
-  - torch
-  - torch_tensorrt
-  - tensorrt
-  - fx2trt
-input:
-  input0:
-    - 3
-    - 224
-    - 224
-  num_inputs: 1
-model:
-  filename: model.plan
-  name: vgg16
-model_torch:
-  filename: model_torch.pt
-  name: vgg16
-runtime:
-  device: 0
-  precision:
-    - fp32
-    - fp16
-```
-
-Note:
-
-1. Please note that measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for `torch_tensorrt` backend.
-2. TensorRT engine filename should end with `.plan` otherwise it will be treated as Torchscript module.
-
### Using CompileSpec options via CLI

Here is the list of `CompileSpec` options that can be provided directly to compile the PyTorch module:

-* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt
-* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`)
-* `--model_torch` : Name of the PyTorch model file (optional, only necessary if fx2trt is a chosen backend)
+* `--backends` : Comma separated string of backends. Eg: torch, torch_compile, dynamo, tensorrt
+* `--model` : Name of the model file (can be a Torchscript module or a TensorRT engine, ending in the `.plan` extension). If the backend is `dynamo` or `torch_compile`, the input should be a PyTorch module (instead of a Torchscript module).
+* `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend)
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT (see the parsing sketch after the example below)
* `--batch_size` : Batch size
* `--precision` : Comma separated list of precisions to build TensorRT engines. Eg: fp32,fp16
@@ -131,10 +62,15 @@ Eg:
--model_torch ${MODELS_DIR}/vgg16_torch.pt \
--precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
--batch_size 1 \
---backends torch,torch_tensorrt,tensorrt,fx2trt \
+--backends torch,ts_trt,dynamo,torch_compile,tensorrt \
--report "vgg_perf_bs1.txt"
```
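The `--inputs` value is a semicolon-separated list of `shape@dtype` entries. A minimal parsing sketch (`parse_inputs` is a hypothetical helper for illustration; the actual parsing in `perf_run.py` may differ):

```python
from typing import List, Tuple

def parse_inputs(spec: str) -> List[Tuple[Tuple[int, ...], str]]:
    """Parse e.g. "(1, 3, 224, 224)@fp32;(1, 128)@int32" into [(shape, dtype), ...]."""
    parsed = []
    for entry in spec.split(";"):
        shape_str, _, dtype = entry.partition("@")
        shape = tuple(int(dim) for dim in shape_str.strip("() ").split(","))
        parsed.append((shape, dtype or "fp32"))  # assume fp32 when no dtype is given
    return parsed

print(parse_inputs("(1, 3, 224, 224)@fp32;(1, 128)@int32"))
# [((1, 3, 224, 224), 'fp32'), ((1, 128), 'int32')]
```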

+Note:
+
+1. Measuring INT8 performance is only supported via a `calibration cache` file or QAT mode for the `torch_tensorrt` backend.
+2. The TensorRT engine filename should end with `.plan`; otherwise it will be treated as a Torchscript module.

### Example models

This tool benchmarks any PyTorch model or Torchscript module. As examples, we provide VGG16, Resnet50, EfficientNet-B0, VIT, and HF-BERT models in `hub.py`, which we test internally for performance.
70 changes: 58 additions & 12 deletions tools/perf/benchmark.sh
@@ -6,62 +6,108 @@ MODELS_DIR="models"
python hub.py

batch_sizes=(1 2 4 8 16 32 64 128 256)
+large_model_batch_sizes=(1 2 4 8 16 32 64)

-#Benchmark VGG16 model

+# Benchmark VGG16 model
+echo "Benchmarking VGG16 model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
---model_torch ${MODELS_DIR}/vgg16_pytorch.pt \
+--model_torch vgg16 \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
---backends torch,torch_tensorrt,tensorrt,fx2trt \
---report "vgg_perf_bs${bs}.txt"
+--truncate \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
+--report "vgg16_perf_bs${bs}.txt"
done

+# Benchmark AlexNet model
+echo "Benchmarking AlexNet model"
+for bs in ${batch_sizes[@]}
+do
+python perf_run.py --model ${MODELS_DIR}/alexnet_scripted.jit.pt \
+--model_torch alexnet \
+--precision fp32,fp16 --inputs="(${bs}, 3, 227, 227)" \
+--batch_size ${bs} \
+--truncate \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
+--report "alexnet_perf_bs${bs}.txt"
+done

# Benchmark Resnet50 model
echo "Benchmarking Resnet50 model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \
---model_torch ${MODELS_DIR}/resnet50_pytorch.pt \
+--model_torch resnet50 \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
---backends torch,torch_tensorrt,tensorrt,fx2trt \
---report "rn50_perf_bs${bs}.txt"
+--truncate \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
+--report "resnet50_perf_bs${bs}.txt"
done

# Benchmark VIT model
echo "Benchmarking VIT model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \
+--model_torch vit \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
---backends torch,torch_tensorrt,tensorrt \
+--truncate \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
--report "vit_perf_bs${bs}.txt"
done

+# Benchmark VIT Large model
+echo "Benchmarking VIT Large model"
+for bs in ${large_model_batch_sizes[@]}
+do
+python perf_run.py --model ${MODELS_DIR}/vit_large_scripted.jit.pt \
+--model_torch vit_large \
+--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+--truncate \
+--batch_size ${bs} \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
+--report "vit_large_perf_bs${bs}.txt"
+done

# Benchmark EfficientNet-B0 model
echo "Benchmarking EfficientNet-B0 model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \
---model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \
+--model_torch efficientnet_b0 \
--precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
--batch_size ${bs} \
---backends torch,torch_tensorrt,tensorrt,fx2trt \
---report "eff_b0_perf_bs${bs}.txt"
+--truncate \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
+--report "efficientnet_b0_perf_bs${bs}.txt"
done

+# Benchmark Stable Diffusion UNet model
+echo "Benchmarking SD UNet model"
+for bs in ${large_model_batch_sizes[@]}
+do
+python perf_run.py --model_torch sd_unet \
+--precision fp32,fp16 --inputs="(${bs}, 4, 128, 128)@fp16;(${bs})@fp16;(${bs}, 1, 768)@fp16" \
+--batch_size ${bs} \
+--backends torch,dynamo,torch_compile,inductor \
+--truncate \
+--report "sd_unet_perf_bs${bs}.txt"
+done

# Benchmark BERT model
echo "Benchmarking Huggingface BERT base model"
for bs in ${batch_sizes[@]}
do
python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \
--model_torch "bert_base_uncased" \
--precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
--batch_size ${bs} \
---backends torch,torch_tensorrt \
+--backends torch,ts_trt,dynamo,torch_compile,inductor \
+--truncate \
--report "bert_base_perf_bs${bs}.txt"
done
19 changes: 0 additions & 19 deletions tools/perf/config/vgg16.yml

This file was deleted.

20 changes: 0 additions & 20 deletions tools/perf/config/vgg16_trt.yml

This file was deleted.

35 changes: 20 additions & 15 deletions tools/perf/custom_models.py
@@ -1,10 +1,18 @@
 import torch
 import torch.nn as nn
-from transformers import BertModel, BertTokenizer, BertConfig
 import torch.nn.functional as F


 def BertModule():
+    from transformers import BertModel
+
+    model_name = "bert-base-uncased"
+    model = BertModel.from_pretrained(model_name, torchscript=True)
+    model.eval()
+    return model
+
+
+def BertInputs():
+    from transformers import BertTokenizer
+
     model_name = "bert-base-uncased"
     enc = BertTokenizer.from_pretrained(model_name)
     text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
@@ -15,16 +23,13 @@ def BertModule():
     segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
     tokens_tensor = torch.tensor([indexed_tokens])
     segments_tensors = torch.tensor([segments_ids])
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
-    )
-    model = BertModel(config)
-    model.eval()
-    model = BertModel.from_pretrained(model_name, torchscript=True)
-    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
-    return traced_model
+    return [tokens_tensor, segments_tensors]
+
+
+def StableDiffusionUnet():
+    from diffusers import DiffusionPipeline
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+    )
+    return pipe.unet
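With this refactor, tracing moves out of `BertModule`; the Torchscript artifact is presumably produced at the call site instead. A minimal sketch of how the helpers might be combined to create the traced module that `benchmark.sh` consumes (an assumption for illustration, not code from this commit):

```python
import torch
from custom_models import BertModule, BertInputs

model = BertModule()   # eager HF BERT loaded with torchscript=True
inputs = BertInputs()  # [tokens_tensor, segments_tensors]

# Trace to Torchscript when a .jit.pt artifact is needed (e.g., for the ts_trt backend)
traced = torch.jit.trace(model, inputs)
traced.save("models/bert_base_uncased_traced.jit.pt")
```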