| Navigate to | Part 3: Optimizing Triton Configuration | Part 5: Building Model Ensembles |
| ----------- | --------------------------------------- | -------------------------------- |
Model acceleration is a complex, nuanced topic. The viability of techniques like graph optimization, pruning, knowledge distillation, and quantization depends heavily on the structure of the model. Each of these topics is a vast field of research in its own right, and building custom tooling requires a massive engineering investment.
Rather than exhaustively outlining the ecosystem, for brevity and objectivity this discussion focuses on the tools and features recommended when deploying models with the Triton Inference Server.
Triton Inference Server has a concept called a "Triton Backend" or "Backend". A backend is essentially the implementation that executes the model. It can be a wrapper around a popular deep learning framework like PyTorch, TensorFlow, TensorRT, or ONNX Runtime, or users can build their own backend customized to their model and use case. Each backend has its own specific options for acceleration.
Performance tuning of models served by Triton is discussed broadly here; this document goes into more detail below.
Acceleration recommendations depend on two main factors:
- Type of hardware: Triton users can choose to run models on GPU or CPU. Owing to the parallelism they provide, GPUs open up many avenues of performance acceleration; models using PyTorch, TensorFlow, ONNX Runtime, and TensorRT can take advantage of these benefits. For CPUs, Triton users can leverage the OpenVINO backend for acceleration.
- Type of model: Users usually leverage one or more of three classes of model: shallow models like random forests; neural networks like BERT or CNNs; and lastly, large transformer models, which are usually too big to fit in a single GPU's memory. Each model category leverages different optimizations to accelerate performance.
With these broad categories in mind, let's drill down into the specific scenarios and the decision-making process for picking the most appropriate Triton backend for a use case, along with a brief discussion of possible optimizations.
As mentioned before, acceleration of deep learning models can be achieved in many ways. Graph-level optimizations like layer fusion reduce the number of GPU kernels that need to be launched during execution, make model execution more memory efficient, and increase the density of operations. Once layers are fused, a kernel auto-tuner can pick the right combination of kernels to maximize utilization of GPU resources. Similarly, using lower precision (FP16, INT8, etc.) with techniques like quantization can drastically reduce memory requirements and increase throughput.
The exact performance optimization tactics differ for each GPU based on its hardware design. These are a few of the many challenges NVIDIA TensorRT, an SDK focused on deep learning inference optimization, solves for deep learning practitioners.
While TensorRT works with popular deep learning frameworks like PyTorch, TensorFlow, MXNet, ONNX Runtime, and more, it also has framework-level integrations with PyTorch (Torch-TensorRT) and TensorFlow (TensorFlow-TensorRT) to provide their respective developers with flexibility and fallback mechanisms.
There are three routes for converting models to TensorRT: the C++ API, the Python API, and trtexec/Polygraphy (TensorRT's command-line tools). Refer to this guide for a fully worked example.
That said, there are two main steps. First, convert the model to a TensorRT engine. It is recommended to run the command inside the TensorRT container.
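For instance, one way to start such a container is shown below. This is only a sketch: the image tag is illustrative, and you should pick one that matches your TensorRT and Triton versions.

```bash
# Start an NGC TensorRT container with the current directory mounted,
# so model.onnx is visible inside the container at /workspace.
docker run --gpus=all -it --rm -v $(pwd):/workspace nvcr.io/nvidia/tensorrt:22.11-py3
```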
trtexec --onnx=model.onnx \
--saveEngine=model.plan \
--explicitBatch
Once converted, place the resulting model.plan in the model repository (as described in Part 1) and use tensorrt as the backend in config.pbtxt.
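As a rough sketch, the repository layout and configuration might look like the following. The output tensor name and dimensions are illustrative assumptions (they must match the actual engine bindings), while the input name and shape follow the perf_analyzer commands used later in this tutorial; Triton can also auto-complete much of this configuration from the engine itself.

```
model_repository/
└── text_recognition/
    ├── config.pbtxt
    └── 1/
        └── model.plan
```

```
name: "text_recognition"
backend: "tensorrt"
max_batch_size: 16
input [
  {
    name: "input.1"          # must match the engine's input binding
    data_type: TYPE_FP32
    dims: [ 1, 32, 100 ]
  }
]
output [
  {
    name: "output"           # illustrative; must match the engine's output binding
    data_type: TYPE_FP32
    dims: [ -1 ]             # illustrative
  }
]
```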
Apart from the conversion to TensorRT itself, users can also leverage some CUDA-specific optimizations.
For cases where some of the operators in a model aren't supported by TensorRT, there are three possible options:
- Use one of the framework integrations: TensorRT has two framework integrations: Torch-TensorRT (PyTorch) and TensorFlow-TensorRT (TensorFlow). These integrations have a built-in fallback mechanism that uses the framework backend for parts of the graph TensorRT doesn't directly support.
- Use the ONNX Runtime with TensorRT: Triton users can also leverage this fallback mechanism with the ONNX Runtime (more in the following section).
- Build a plugin: TensorRT allows building plugins that implement custom ops. Users can write their own TensorRT plugins for unsupported ops (recommended for expert users). It is highly encouraged to report such ops so they can be natively supported by TensorRT.
In the case of PyTorch, Torch-TensorRT is an ahead-of-time compiler that converts TorchScript/Torch FX modules into modules targeting TensorRT engines. After compilation, users can use the optimized model the same way they would use a TorchScript model. Check out Getting Started with Torch-TensorRT to learn more, and refer to this guide for a fully worked example demonstrating compilation of a PyTorch model with Torch-TensorRT and deployment on Triton; a minimal compilation sketch follows.
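The sketch below is illustrative rather than specific to this tutorial's model: the file names, input shape, and precision are assumptions, and the saved module would be served with Triton's PyTorch backend.

```python
import torch
import torch_tensorrt

# Load a scripted/traced PyTorch model (file name is illustrative).
model = torch.jit.load("model.pt").eval().cuda()

# Compile with Torch-TensorRT; unsupported ops fall back to TorchScript execution.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],  # example input shape
    enabled_precisions={torch.half},                   # allow FP16 kernels
)

# The result behaves like a TorchScript module and can be saved as model.pt
# inside the Triton model repository.
torch.jit.save(trt_model, "model.pt")
```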
TensorFlow users can make use of TensorFlow-TensorRT, which segments the graph into TensorRT-supported and unsupported subgraphs. The supported subgraphs are then replaced by TensorRT-optimized nodes, producing a graph with both TensorFlow and TensorRT components. Refer to this tutorial explaining the exact steps required to accelerate a model with TensorFlow-TensorRT and deploy it on the Triton Inference Server; a rough conversion sketch follows.
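The snippet below is only an illustration of the conversion API: the SavedModel paths and precision mode are assumptions, and on older TensorFlow versions the precision is passed via conversion_params instead of a direct keyword argument.

```python
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Convert an existing SavedModel; supported subgraphs are replaced
# with TensorRT-optimized nodes (directory names are illustrative).
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="saved_model",
    precision_mode=trt.TrtPrecisionMode.FP16,
)
converter.convert()
converter.save("tftrt_saved_model")  # deployable with Triton's TensorFlow backend
```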
There are three options to accelerate the ONNX Runtime: the TensorRT and CUDA execution providers for GPU, and the OpenVINO execution provider (discussed in a later section) for CPU.
In general, TensorRT provides better optimizations than the CUDA execution provider; however, this depends on the exact structure of the model, or more precisely, on the operators used in the network being accelerated. If all the operators are supported, conversion to TensorRT will yield better performance. When TensorRT is selected as the accelerator, all supported subgraphs are accelerated by TensorRT and the rest of the graph runs on the CUDA execution provider. Users can achieve this with the following additions to the config file.
TensorRT acceleration
optimization {
execution_accelerators {
gpu_execution_accelerator : [ {
name : "tensorrt"
parameters { key: "precision_mode" value: "FP16" }
parameters { key: "max_workspace_size_bytes" value: "1073741824" }
}]
}
}
That said, users can also choose to run models without TensorRT optimization, in which case the CUDA EP is the default execution provider. More details can be found here. Refer to the onnx_tensorrt_config.pbtxt here for a sample configuration file for the Text Recognition model used in Parts 1-3 of this series.
There are a few other ONNX Runtime-specific optimizations. Refer to this section of our ONNX backend docs for more information.
Triton Inference Server also supports acceleration of CPU-only models with OpenVINO. In the configuration file, users can add the following to enable CPU acceleration.
optimization {
execution_accelerators {
cpu_execution_accelerator : [{
name : "openvino"
}]
}
}
While OpenVINO provides software-level optimizations, it is also important to consider the CPU hardware being used. CPUs comprise multiple cores, memory resources, and interconnects. With multiple CPU sockets, these resources can be shared via NUMA (non-uniform memory access). Refer to this section of the Triton documentation for more.
Shallow models like gradient boosted decision trees are used in many pipelines. These models are typically built with libraries like XGBoost, LightGBM, Scikit-learn, cuML, and more. They can be deployed on the Triton Inference Server via the Forest Inference Library (FIL) backend. Check out these examples for more information; a configuration sketch follows below.
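As a hedged sketch only (the model name, feature count, and parameter values are illustrative assumptions; consult the FIL backend documentation for the full set of options), a config.pbtxt for an XGBoost model might look roughly like:

```
name: "my_xgboost_model"          # illustrative
backend: "fil"
max_batch_size: 32768
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ 32 ]                  # illustrative number of features
  }
]
output [
  {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]
instance_group [{ kind: KIND_GPU }]
parameters [
  { key: "model_type" value: { string_value: "xgboost_json" } },
  { key: "output_class" value: { string_value: "true" } }
]
```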
On the other end of the spectrum, deep learning practitioners are drawn to large transformer-based models with billions of parameters. At that scale, models often need different optimizations or have to be parallelized across GPUs, since they may not fit on a single GPU. This parallelization can be achieved via tensor parallelism (splitting each layer's weights across GPUs) or pipeline parallelism (placing different layers on different GPUs). To solve this, users can use the FasterTransformer library and Triton's FasterTransformer backend. Check out this blog for more information!
Before proceeding, please set up a model repository for the Text Recognition model used in Parts 1-3 of this series. Then, navigate to the model repository and launch two containers:
# Server Container
docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v/$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:22.11-py3 bash
# Client Container (on a different terminal)
docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:22.11-py3-sdk bash
Since this is a model we converted to ONNX, and TensorRT acceleration examples are linked throughout the explanation, we will explore the ONNX pathway. There are three cases to consider with the ONNX backend:
- Accelerated ONNX RT execution on GPU w. CUDA execution provider:
ORT_cuda_ep_config.pbtxt
- ONNX RT execution on GPU w. TRT acceleration:
ORT_TRT_config.pbtxt
- ONNX RT execution on CPU w. OpenVINO acceleration:
ORT_openvino_config.pbtxt
While using ONNX RT, there are some general optimizations to consider irrespective of the execution provider: graph-level optimizations, selecting the number and behavior of the threads used to parallelize execution, and memory-usage optimizations. The usefulness of each of these options is highly dependent on the model being deployed; a sketch of how they appear in the configuration follows.
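As a hedged sketch (the values below are illustrative; refer to the ONNX Runtime backend documentation for the exact keys and defaults), these knobs are set in config.pbtxt roughly like this:

```
# Graph optimization level for the ONNX Runtime session (value is illustrative).
optimization { graph : { level : 1 } }

# Thread-pool sizing; 0 typically lets ONNX Runtime pick a default.
parameters { key: "intra_op_thread_count" value: { string_value: "0" } }
parameters { key: "inter_op_thread_count" value: { string_value: "0" } }
```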
With this context, let's launch the Triton Inference Server with the appropriate configuration file.
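Triton reads the model's config.pbtxt, so to switch between the three cases above, copy the desired sample file over it before starting the server. The paths below are illustrative and assume the file names listed earlier.

```bash
# Example: use the TensorRT-accelerated ONNX Runtime configuration.
cp ORT_TRT_config.pbtxt model_repository/text_recognition/config.pbtxt
```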
tritonserver --model-repository=/models
NOTE: These benchmarks are just to illustrate the general curve of the performance gain. This is not the highest throughput obtainable via Triton, as resource utilization features (e.g., dynamic batching) haven't been enabled. Refer to the Model Analyzer tutorial for the best deployment configuration once model optimizations are done.
NOTE: These settings are meant to maximize throughput. Refer to the Model Analyzer tutorial, which covers managing latency requirements.
For reference, the baseline performance is as follows:
Inferences/Second vs. Client Average Batch Latency
Concurrency: 2, throughput: 4191.7 infer/sec, latency 7633 usec
For this model, an exhaustive search for the best convolution algorithm is enabled. Learn about more options.
## Additions to Config
parameters { key: "cudnn_conv_algo_search" value: { string_value: "0" } }
parameters { key: "gpu_mem_limit" value: { string_value: "4294967200" } }
## Perf Analyzer Query
perf_analyzer -m text_recognition -b 16 --shape input.1:1,32,100 --concurrency-range 64
...
Inferences/Second vs. Client Average Batch Latency
Concurrency: 2, throughput: 4257.9 infer/sec, latency 7672 usec
While specifying the use of the TensorRT execution provider, the CUDA execution provider is used as a fallback for operators not supported by TensorRT. It is recommended to use TensorRT natively if all operators are supported, as the performance boost and optimization options are considerably better. In this case, the TensorRT accelerator has been used with lower FP16 precision.
## Additions to Config
optimization {
graph : {
level : 1
}
execution_accelerators {
gpu_execution_accelerator : [ {
name : "tensorrt",
parameters { key: "precision_mode" value: "FP16" },
parameters { key: "max_workspace_size_bytes" value: "1073741824" }
}]
}
}
## Perf Analyzer Query
perf_analyzer -m text_recognition -b 16 --shape input.1:1,32,100 --concurrency-range 2
...
Inferences/Second vs. Client Average Batch Latency
Concurrency: 2, throughput: 11820.2 infer/sec, latency 2706 usec
Triton users can also use OpenVINO for CPU deployment. This can be enabled via the following additions to the configuration:
optimization { execution_accelerators {
cpu_execution_accelerator : [ {
name : "openvino"
} ]
}}
Since comparing one CPU with one GPU is not an apples-to-apples comparison in most cases, we encourage benchmarking on the user's local CPU hardware. Learn more
There are many other features for each backend that can be enabled depending on the needs of specific models. Refer to this protobuf for the complete list of possible features and optimizations.
The sections above describe converting models and using different accelerators, and provide general guidelines for building an intuition about which path to take when considering optimizations. These are manual explorations that consume considerable time. To check conversion coverage and explore a subset of the possible optimizations, users can make use of the Model Navigator tool.
In this tutorial, we covered a plethora of optimization options available to accelerate models while using the Triton Inference Server. This is Part 4 of a 6-part tutorial series covering the challenges faced in deploying deep learning models to production. Part 5 covers building a model ensemble. Part 3 and Part 4 focus on two different aspects, resource utilization and framework-level model acceleration respectively; using both techniques in conjunction will lead to the best possible performance. Since the specific selections are highly dependent on workloads, models, SLAs, and hardware resources, this process varies for each user. We highly encourage users to experiment with all of these features to find the best deployment configuration for their use case.