Skip to content

Commit

Permalink
cherry-pick: Port most changes from main
Browse files Browse the repository at this point in the history
- Excluded all changes to `docs` and `.github` directories; did include
documentation changes and all other commits, with the exception of #2451
and #2445 for reasons discussed
- Made necessary changes to switch over to Torch 2.2.0 rc builds,
including updating imports
  • Loading branch information
gs-olive committed Jan 4, 2024
1 parent b6dd22b commit c82408c
Show file tree
Hide file tree
Showing 88 changed files with 4,220 additions and 685 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/install-torch-tensorrt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -eou pipefail
# Source conda so it's available to the script environment
source ${BUILD_ENV_FILE}
${CONDA_RUN} ${PIP_INSTALL_TORCH} torch==2.1.2 torchvision==0.16.2 pyyaml
${CONDA_RUN} ${PIP_INSTALL_TORCH} torch==2.2.0 torchvision==0.17.0 pyyaml
export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

- Bazel 6.2.1
- Libtorch 2.1.1
- Libtorch 2.2.0
- CUDA 12.1
- cuDNN 8.9.5
- TensorRT 8.6.1
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ http_archive(
name = "libtorch",
build_file = "@//third_party/libtorch:BUILD",
strip_prefix = "libtorch",
urls = ["https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu121.zip"],
urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip"],
)

http_archive(
name = "libtorch_pre_cxx11_abi",
build_file = "@//third_party/libtorch:BUILD",
strip_prefix = "libtorch",
urls = ["https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.1.1%2Bcu121.zip"],
urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-shared-with-deps-2.2.0%2Bcu121.zip"],
)

# Download these tarballs manually from the NVIDIA website
Expand Down
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ void ConversionCtx::RecordNewITensor(const torch::jit::Value* value, nvinfer1::I

std::string ConversionCtx::SerializeEngine() {
#if NV_TENSORRT_MAJOR > 7
auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);
auto serialized_network = make_trt(builder->buildSerializedNetwork(*net, *cfg));
if (!serialized_network) {
TORCHTRT_THROW_ERROR("Building serialized network failed in TensorRT");
}
Expand Down
147 changes: 86 additions & 61 deletions core/conversion/converters/impl/conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,74 @@ namespace converters {
namespace impl {
namespace {

void add_output_padding(nvinfer1::Dims& padding, nvinfer1::Dims& out_padding, bool& has_output_padding) {
int nbSpatialDims = out_padding.nbDims;
// When there is out_padding, if padding is larger than out_padding, just adjust padding Or reduce out_padding as
// minimum as possible.
for (int i = 0; i < nbSpatialDims; ++i) {
if (padding.d[i] - out_padding.d[i] >= 0) {
padding.d[i] -= out_padding.d[i];
out_padding.d[i] = 0;
} else {
// Reduce out_padding as possible.
out_padding.d[i] -= padding.d[i];
padding.d[i] = 0;
has_output_padding = true;
}
}
}

nvinfer1::ILayer* add_bias_layer(
ConversionCtx* ctx,
nvinfer1::ITensor* input_tensor,
nvinfer1::Dims& input_dims,
nvinfer1::Dims& output_padding,
Weights& bias) {
nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
// Add padding layer
nvinfer1::ITensor* start;
nvinfer1::ITensor* totalPadding;
auto in_nbDims = input_dims.nbDims;
std::vector<int32_t> startVec(in_nbDims, 0);
std::vector<int32_t> totalPaddingVec(in_nbDims, 0);
int32_t diff = in_nbDims - output_padding.nbDims;
for (int32_t i = diff; i < in_nbDims; i++) {
int32_t idx = i - diff;
startVec[i] = 0; // Don't need begin padding, only post padding
totalPaddingVec[i] = output_padding.d[idx];
}
start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32));
totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32));

const auto size =
ctx->net->addElementWise(*input_shape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);

nvinfer1::Dims stride;
stride.nbDims = in_nbDims;
for (int64_t i = 0; i < in_nbDims; i++) {
stride.d[i] = 1;
}
const auto& dummy = stride;
auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride);
sliceLayer->setInput(1, *start);
sliceLayer->setInput(2, *size);
sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0);

nvinfer1::Dims constantDims;
constantDims.nbDims = in_nbDims;
for (int64_t i = 0; i < in_nbDims; i++) {
constantDims.d[i] = 1;
}
constantDims.d[diff - 1] =
bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast
auto const_layer = ctx->net->addConstant(constantDims, bias.data);
auto bias_layer =
ctx->net->addElementWise(*slice_output, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);

return bias_layer;
}

bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
// Input to conv/deconv
auto in = args[0].ITensor();
Expand Down Expand Up @@ -76,16 +144,29 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)

nvinfer1::ILayer* layer = nullptr;
if (transposed) {
nvinfer1::IDeconvolutionLayer* deconvLayer =
ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
// Fix padding based on output_padding provided
nvinfer1::Dims begPadding = padding;
bool hasOutputPadding = false;
add_output_padding(padding, out_padding, hasOutputPadding);

nvinfer1::IDeconvolutionLayer* deconvLayer = ctx->net->addDeconvolutionNd(
*in, kernel_dims.d[0], filter_dim, kernel_weights, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
deconvLayer->setStrideNd(stride);
deconvLayer->setDilationNd(dilation);
deconvLayer->setNbGroups(groups);
deconvLayer->setPaddingNd(padding);
deconvLayer->setPrePadding(begPadding);
deconvLayer->setPostPadding(padding);

// Set deconv kernel weights
deconvLayer->setInput(1, *kernel);
TORCHTRT_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
layer = deconvLayer;
if (hasOutputPadding) {
LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0);
auto dims = in->getDimensions();
layer = add_bias_layer(ctx, tensorPtr, dims, out_padding, bias);
}
} else {
nvinfer1::IConvolutionLayer* convLayer =
ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
Expand Down Expand Up @@ -155,20 +236,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
// https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
nvinfer1::Dims begPadding = padding;
bool hasOutputPadding = false;
int nbSpatialDims = out_padding.nbDims;
// When there is out_padding, if padding is larger than out_padding, just adjust padding Or reduce out_padding as
// minimum as possible.
for (int i = 0; i < nbSpatialDims; ++i) {
if (padding.d[i] - out_padding.d[i] >= 0) {
padding.d[i] -= out_padding.d[i];
out_padding.d[i] = 0;
} else {
// Reduce out_padding as possible.
out_padding.d[i] -= padding.d[i];
padding.d[i] = 0;
hasOutputPadding = true;
}
}
add_output_padding(padding, out_padding, hasOutputPadding);

// shape of deconvolution's weight: [in, out/groups, ...]
// If there is still output padding, remove the bias. Bias will be added below.
Expand All @@ -190,51 +258,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
#endif
if (hasOutputPadding) {
LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);

// Add padding layer
nvinfer1::ITensor* start;
nvinfer1::ITensor* totalPadding;
auto in_nbDims = orig_dims.nbDims;
std::vector<int32_t> startVec(in_nbDims, 0);
std::vector<int32_t> totalPaddingVec(in_nbDims, 0);
int32_t diff = in_nbDims - out_padding.nbDims;
for (int32_t i = diff; i < in_nbDims; i++) {
int32_t idx = i - diff;
startVec[i] = 0; // Don't need begin padding, only post padding
totalPaddingVec[i] = out_padding.d[idx];
}
start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32));
totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32));

nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
nvinfer1::ITensor* deconvOutShape = ctx->net->addShape(*tensorPtr)->getOutput(0);
const auto size =
ctx->net->addElementWise(*deconvOutShape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);

nvinfer1::Dims stride;
stride.nbDims = in_nbDims;
for (int64_t i = 0; i < in_nbDims; i++) {
stride.d[i] = 1;
}
const auto& dummy = stride;
auto* sliceLayer = ctx->net->addSlice(*tensorPtr, dummy, dummy, stride);
sliceLayer->setInput(1, *start);
sliceLayer->setInput(2, *size);
sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
tensorPtr = sliceLayer->getOutput(0);

nvinfer1::Dims constantDims;
constantDims.nbDims = in_nbDims;
for (int64_t i = 0; i < in_nbDims; i++) {
constantDims.d[i] = 1;
}
constantDims.d[diff - 1] =
bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast
auto const_layer = ctx->net->addConstant(constantDims, bias.data);
auto add_bias_layer =
ctx->net->addElementWise(*tensorPtr, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);

new_layer = add_bias_layer;
new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias);
} else {
new_layer = deconv;
}
Expand Down
39 changes: 34 additions & 5 deletions core/conversion/converters/impl/matrix_multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,49 @@ auto mm_registrations TORCHTRT_UNUSED =
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);

auto selfDims = self->getDimensions().nbDims;
auto otherDims = other->getDimensions().nbDims;

bool squeezeFront = false;
bool squeezeBack = false;

if (selfDims == 1 && selfDims < otherDims) {
squeezeFront = true;
} else if (otherDims == 1 && otherDims < selfDims) {
// Append a 1 to the end of the shape before padding front to match self
other = addPadding(ctx, n, other, 2, true, false);
otherDims = other->getDimensions().nbDims;
squeezeBack = true;
}

// Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
// necessary.
if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false);
} else {
other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
if (selfDims < otherDims) {
self = addPadding(ctx, n, self, otherDims, false, false);
} else if (otherDims < selfDims) {
other = addPadding(ctx, n, other, selfDims, false, false);
}

auto mm_layer = ctx->net->addMatrixMultiply(
*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);

TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
mm_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
auto out = mm_layer->getOutput(0);

if (squeezeFront || squeezeBack) {
auto squeezeDimOffset = squeezeFront ? 2 : 1;
auto reshapeDims =
util::squeezeDims(out->getDimensions(), out->getDimensions().nbDims - squeezeDimOffset);
auto shuffle_layer = ctx->net->addShuffle(*out);
LOG_DEBUG("Squeezing matmul output for 1d correction: " << reshapeDims);
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
shuffle_layer->setReshapeDimensions(reshapeDims);
shuffle_layer->setName((util::node_info(n) + "_squeeze").c_str());
out = shuffle_layer->getOutput(0);
}
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
Expand Down
2 changes: 1 addition & 1 deletion core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
// Serialize TensorRT engine
auto serialized_trt_engine = self->cuda_engine->serialize();
auto serialized_trt_engine = make_trt(self->cuda_engine->serialize());

// Adding device info related meta data to the serialized file
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/torch_tensorrt/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#define STR(x) XSTR(x)

#define TORCH_TENSORRT_MAJOR_VERSION 2
#define TORCH_TENSORRT_MINOR_VERSION 1
#define TORCH_TENSORRT_MINOR_VERSION 2
#define TORCH_TENSORRT_PATCH_VERSION 0
#define TORCH_TENSORRT_VERSION \
STR(TORCH_TENSORRT_MAJOR_VERSION) \
Expand Down
1 change: 1 addition & 0 deletions dev_dep_versions.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__version__: "2.2.0"
__cuda_version__: "12.1"
__cudnn_version__: "8.9"
__tensorrt_version__: "8.6"
4 changes: 2 additions & 2 deletions docker/dist-build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
TOP_DIR=$(cd $(dirname $0); pwd)/..

if [[ -z "${USE_CXX11}" ]]; then
BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/cu121 -w dist"
BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/test/cu121 -w dist"
else
BUILD_CMD="python -m pip wheel . --config-setting="--build-option=--use-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/cu121 -w dist"
BUILD_CMD="python -m pip wheel . --config-setting="--build-option=--use-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/test/cu121 -w dist"
fi

# TensorRT restricts our pip version
Expand Down
3 changes: 2 additions & 1 deletion docsrc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ User Guide
:maxdepth: 1
:hidden:


user_guide/dynamic_shapes
user_guide/ptq
user_guide/saving_models
Expand Down Expand Up @@ -206,4 +207,4 @@ Legacy Further Information (TorchScript)
* `GTC 2021 Fall Talk <https://www.nvidia.com/en-us/on-demand/session/gtcfall21-a31107/>`_
* `PyTorch Ecosystem Day 2021 <https://assets.pytorch.org/pted2021/posters/I6.png>`_
* `PyTorch Developer Conference 2021 <https://s3.amazonaws.com/assets.pytorch.org/ptdd2021/posters/D2.png>`_
* `PyTorch Developer Conference 2022 <https://pytorch.s3.amazonaws.com/posters/ptc2022/C04.pdf>`_
* `PyTorch Developer Conference 2022 <https://pytorch.s3.amazonaws.com/posters/ptc2022/C04.pdf>`_
17 changes: 8 additions & 9 deletions docsrc/user_guide/saving_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ Saving models compiled with Torch-TensorRT
:members:
:undoc-members:
:show-inheritance:

Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.

Dynamo IR
Dynamo IR
-------------

Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
Expand Down Expand Up @@ -41,7 +41,7 @@ The following code illustrates this approach.
b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.

.. code-block:: python
Expand All @@ -51,18 +51,17 @@ b) ExportedProgram
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
exp_program = torch_tensorrt.dynamo.trace(model, inputs)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs, exp_program.call_spec, ir="exported_program")
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
torch.export.save(trt_exp_program, "trt_model.ep")
# Later, you can load it and run inference
# Later, you can load it and run inference
model = torch.export.load("trt_model.ep")
model(*inputs)
`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes, stiches all the nodes together and creates an ExportedProgram.
This is needed as `torch.export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).

.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341

Expand Down
4 changes: 2 additions & 2 deletions examples/int8/training/vgg16/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
tensorboard>=1.14.0
protobuf==3.20.*
nvidia-pyindex
--extra-index-url https://pypi.ngc.nvidia.com
pytorch-quantization>=2.1.2
--extra-index-url https://pypi.nvidia.com
pytorch-quantization
tqdm
Loading

0 comments on commit c82408c

Please sign in to comment.