From f1fa0b0ae67e6b818313a41269d237e2f15d87b3 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 3 Jan 2024 12:26:37 -0800 Subject: [PATCH] cherry-pick: Port most changes from `main` - Excluded all changes to the `docs` and `.github` directories, but did include the documentation source (`docsrc`) changes; ported all other commits, with the exception of #2451 and #2445, for the reasons discussed - Made the necessary changes to switch over to the Torch 2.2.0 RC builds, including updating imports --- README.md | 2 +- WORKSPACE | 4 +- .../conversionctx/ConversionCtx.cpp | 2 +- .../converters/impl/conv_deconv.cpp | 147 ++-- .../converters/impl/matrix_multiply.cpp | 39 +- core/runtime/register_jit_hooks.cpp | 2 +- cpp/include/torch_tensorrt/macros.h | 2 +- dev_dep_versions.yml | 1 + docker/dist-build.sh | 4 +- docsrc/index.rst | 3 +- docsrc/user_guide/saving_models.rst | 17 +- examples/int8/training/vgg16/requirements.txt | 4 +- noxfile.py | 185 ++--- py/requirements.txt | 3 +- py/torch_tensorrt/_Input.py | 7 +- py/torch_tensorrt/dynamo/_compiler.py | 34 +- py/torch_tensorrt/dynamo/_exporter.py | 57 +- py/torch_tensorrt/dynamo/_settings.py | 7 +- py/torch_tensorrt/dynamo/_tracer.py | 91 +-- .../dynamo/conversion/_ConverterRegistry.py | 25 + .../dynamo/conversion/aten_ops_converters.py | 677 ++++++++++++++++-- .../dynamo/conversion/converter_utils.py | 3 +- .../dynamo/conversion/impl/__init__.py | 6 +- .../dynamo/conversion/impl/activation/ops.py | 30 - .../dynamo/conversion/impl/addmm.py | 34 + .../dynamo/conversion/impl/cat.py | 11 +- .../dynamo/conversion/impl/conv.py | 7 +- .../dynamo/conversion/impl/deconv.py | 9 +- .../conversion/impl/elementwise/base.py | 8 +- .../dynamo/conversion/impl/elementwise/ops.py | 174 +++-- .../dynamo/conversion/impl/embedding.py | 137 +++- .../dynamo/conversion/impl/grid.py | 45 ++ .../dynamo/conversion/impl/pad.py | 205 ++++++ .../dynamo/conversion/impl/reduce.py | 31 +- .../dynamo/conversion/impl/select.py | 109 +-- .../dynamo/conversion/impl/slice/ops.py | 30 +- .../dynamo/conversion/impl/topk.py | 136 ++++ .../dynamo/conversion/impl/unary/ops.py | 42 +- .../dynamo/conversion/impl/upsample.py | 67 ++ .../dynamo/conversion/ops_evaluators.py | 12 + .../dynamo/lowering/_decompositions.py | 16 - .../lowering/passes/_aten_lowering_pass.py | 2 + .../lowering/passes/constant_folding.py | 2 +- .../passes/lower_efficient_attention.py | 34 +- .../dynamo/lowering/passes/lower_linear.py | 47 ++ .../partitioning/_adjacency_partitioner.py | 6 +- .../partitioning/_global_partitioner.py | 25 +- .../dynamo/tools/opset_coverage.py | 20 +- pyproject.toml | 6 +- .../converters/test_matrix_multiply.cpp | 132 +++- tests/modules/custom_models.py | 11 +- tests/modules/requirements.txt | 4 +- .../dynamo/backend/test_specialized_models.py | 34 + tests/py/dynamo/conversion/test_addmm_aten.py | 65 ++ tests/py/dynamo/conversion/test_amin_aten.py | 95 +++ .../py/dynamo/conversion/test_arange_aten.py | 34 + .../py/dynamo/conversion/test_argmax_aten.py | 8 +- .../py/dynamo/conversion/test_argmin_aten.py | 41 ++ .../conversion/test_bitwise_and_aten.py | 73 ++ .../conversion/test_bitwise_not_aten.py | 33 + .../dynamo/conversion/test_bitwise_or_aten.py | 73 ++ .../conversion/test_bitwise_xor_aten.py | 73 ++ tests/py/dynamo/conversion/test_casts.py | 15 + tests/py/dynamo/conversion/test_chunk_aten.py | 82 +++ tests/py/dynamo/conversion/test_clamp_aten.py | 26 +- tests/py/dynamo/conversion/test_clip_aten.py | 35 +- tests/py/dynamo/conversion/test_copy_aten.py | 31 +
.../conversion/test_embedding_bag_aten.py | 141 ++++ tests/py/dynamo/conversion/test_eq_aten.py | 69 ++ tests/py/dynamo/conversion/test_ge_aten.py | 69 ++ tests/py/dynamo/conversion/test_grid_aten.py | 150 ++++ tests/py/dynamo/conversion/test_gt_aten.py | 66 ++ tests/py/dynamo/conversion/test_index_aten.py | 27 + tests/py/dynamo/conversion/test_le_aten.py | 69 ++ tests/py/dynamo/conversion/test_lt_aten.py | 66 ++ tests/py/dynamo/conversion/test_ne_aten.py | 69 ++ tests/py/dynamo/conversion/test_pad_aten.py | 241 +++++++ tests/py/dynamo/conversion/test_sort_aten.py | 34 + tests/py/dynamo/conversion/test_tile_aten.py | 75 ++ tests/py/dynamo/conversion/test_trunc_aten.py | 52 ++ tests/py/dynamo/conversion/test_upsample.py | 97 +++ .../lowering/test_aten_lowering_passes.py | 108 +++ .../py/dynamo/lowering/test_decompositions.py | 72 -- tests/py/dynamo/models/test_dyn_models.py | 4 +- tests/py/dynamo/models/test_export_serde.py | 56 +- .../WORKSPACE.x86_64.release.rhel.tmpl | 4 +- version.txt | 2 +- 87 files changed, 4219 insertions(+), 684 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/addmm.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/grid.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/pad.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/topk.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/upsample.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py create mode 100644 tests/py/dynamo/conversion/test_addmm_aten.py create mode 100644 tests/py/dynamo/conversion/test_amin_aten.py create mode 100644 tests/py/dynamo/conversion/test_arange_aten.py create mode 100644 tests/py/dynamo/conversion/test_argmin_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_and_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_not_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_or_aten.py create mode 100644 tests/py/dynamo/conversion/test_bitwise_xor_aten.py create mode 100644 tests/py/dynamo/conversion/test_chunk_aten.py create mode 100644 tests/py/dynamo/conversion/test_copy_aten.py create mode 100644 tests/py/dynamo/conversion/test_embedding_bag_aten.py create mode 100644 tests/py/dynamo/conversion/test_eq_aten.py create mode 100644 tests/py/dynamo/conversion/test_ge_aten.py create mode 100644 tests/py/dynamo/conversion/test_grid_aten.py create mode 100644 tests/py/dynamo/conversion/test_gt_aten.py create mode 100644 tests/py/dynamo/conversion/test_le_aten.py create mode 100644 tests/py/dynamo/conversion/test_lt_aten.py create mode 100644 tests/py/dynamo/conversion/test_ne_aten.py create mode 100644 tests/py/dynamo/conversion/test_pad_aten.py create mode 100644 tests/py/dynamo/conversion/test_sort_aten.py create mode 100644 tests/py/dynamo/conversion/test_tile_aten.py create mode 100644 tests/py/dynamo/conversion/test_trunc_aten.py create mode 100644 tests/py/dynamo/conversion/test_upsample.py diff --git a/README.md b/README.md index b3749ca609..aab97ee2d1 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd These are the dependencies used to verify the test cases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
- Bazel 6.2.1 -- Libtorch 2.1.1 +- Libtorch 2.2.0 - CUDA 12.1 - cuDNN 8.9.5 - TensorRT 8.6.1 diff --git a/WORKSPACE b/WORKSPACE index 73deb8506c..62e88ccd52 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -54,14 +54,14 @@ http_archive( name = "libtorch", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu121.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.1.1%2Bcu121.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-shared-with-deps-2.2.0%2Bcu121.zip"], ) # Download these tarballs manually from the NVIDIA website diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index c9a76602c2..2eb363706f 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -164,7 +164,7 @@ void ConversionCtx::RecordNewITensor(const torch::jit::Value* value, nvinfer1::I std::string ConversionCtx::SerializeEngine() { #if NV_TENSORRT_MAJOR > 7 - auto serialized_network = builder->buildSerializedNetwork(*net, *cfg); + auto serialized_network = make_trt(builder->buildSerializedNetwork(*net, *cfg)); if (!serialized_network) { TORCHTRT_THROW_ERROR("Building serialized network failed in TensorRT"); } diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp index fc0e97b7ee..083a4ecc2f 100644 --- a/core/conversion/converters/impl/conv_deconv.cpp +++ b/core/conversion/converters/impl/conv_deconv.cpp @@ -10,6 +10,74 @@ namespace converters { namespace impl { namespace { +void add_output_padding(nvinfer1::Dims& padding, nvinfer1::Dims& out_padding, bool& has_output_padding) { + int nbSpatialDims = out_padding.nbDims; + // When out_padding is present: if padding is larger than out_padding, just adjust padding; otherwise reduce + // out_padding as much as possible. + for (int i = 0; i < nbSpatialDims; ++i) { + if (padding.d[i] - out_padding.d[i] >= 0) { + padding.d[i] -= out_padding.d[i]; + out_padding.d[i] = 0; + } else { + // Reduce out_padding as much as possible.
+ out_padding.d[i] -= padding.d[i]; + padding.d[i] = 0; + has_output_padding = true; + } + } +} + +nvinfer1::ILayer* add_bias_layer( + ConversionCtx* ctx, + nvinfer1::ITensor* input_tensor, + nvinfer1::Dims& input_dims, + nvinfer1::Dims& output_padding, + Weights& bias) { + nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0); + // Add padding layer + nvinfer1::ITensor* start; + nvinfer1::ITensor* totalPadding; + auto in_nbDims = input_dims.nbDims; + std::vector<int64_t> startVec(in_nbDims, 0); + std::vector<int64_t> totalPaddingVec(in_nbDims, 0); + int32_t diff = in_nbDims - output_padding.nbDims; + for (int32_t i = diff; i < in_nbDims; i++) { + int32_t idx = i - diff; + startVec[i] = 0; // Don't need begin padding, only post padding + totalPaddingVec[i] = output_padding.d[idx]; + } + start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32)); + totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32)); + + const auto size = + ctx->net->addElementWise(*input_shape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); + + nvinfer1::Dims stride; + stride.nbDims = in_nbDims; + for (int64_t i = 0; i < in_nbDims; i++) { + stride.d[i] = 1; + } + const auto& dummy = stride; + auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride); + sliceLayer->setInput(1, *start); + sliceLayer->setInput(2, *size); + sliceLayer->setMode(nvinfer1::SliceMode::kFILL); + nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0); + + nvinfer1::Dims constantDims; + constantDims.nbDims = in_nbDims; + for (int64_t i = 0; i < in_nbDims; i++) { + constantDims.d[i] = 1; + } + constantDims.d[diff - 1] = + bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast + auto const_layer = ctx->net->addConstant(constantDims, bias.data); + auto bias_layer = + ctx->net->addElementWise(*slice_output, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM); + + return bias_layer; +} + bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { // Input to conv/deconv auto in = args[0].ITensor(); @@ -76,16 +144,29 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) nvinfer1::ILayer* layer = nullptr; if (transposed) { - nvinfer1::IDeconvolutionLayer* deconvLayer = - ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data); + // Fix padding based on output_padding provided + nvinfer1::Dims begPadding = padding; + bool hasOutputPadding = false; + add_output_padding(padding, out_padding, hasOutputPadding); + + nvinfer1::IDeconvolutionLayer* deconvLayer = ctx->net->addDeconvolutionNd( + *in, kernel_dims.d[0], filter_dim, kernel_weights, hasOutputPadding ?
nvinfer1::Weights{} : bias.data); deconvLayer->setStrideNd(stride); deconvLayer->setDilationNd(dilation); deconvLayer->setNbGroups(groups); - deconvLayer->setPaddingNd(padding); + deconvLayer->setPrePadding(begPadding); + deconvLayer->setPostPadding(padding); + // Set deconv kernel weights deconvLayer->setInput(1, *kernel); TORCHTRT_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n); layer = deconvLayer; + if (hasOutputPadding) { + LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding); + nvinfer1::ITensor* tensorPtr = deconvLayer->getOutput(0); + auto dims = in->getDimensions(); + layer = add_bias_layer(ctx, tensorPtr, dims, out_padding, bias); + } } else { nvinfer1::IConvolutionLayer* convLayer = ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data); @@ -155,20 +236,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734 nvinfer1::Dims begPadding = padding; bool hasOutputPadding = false; - int nbSpatialDims = out_padding.nbDims; - // When there is out_padding, if padding is larger than out_padding, just adjust padding Or reduce out_padding as - // minimum as possible. - for (int i = 0; i < nbSpatialDims; ++i) { - if (padding.d[i] - out_padding.d[i] >= 0) { - padding.d[i] -= out_padding.d[i]; - out_padding.d[i] = 0; - } else { - // Reduce out_padding as possible. - out_padding.d[i] -= padding.d[i]; - padding.d[i] = 0; - hasOutputPadding = true; - } - } + add_output_padding(padding, out_padding, hasOutputPadding); // shape of deconvolution's weight: [in, out/groups, ...] // If there is still output padding, remove the bias. Bias will be added below. 
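In PyTorch terms, the new `add_bias_layer` helper above zero-pads the trailing edge of each spatial dimension by `out_padding` (the fill-mode slice) and then adds the bias as a broadcast constant over the channel dimension. A minimal sketch of that arithmetic, with illustrative shapes that are not part of the patch:

    import torch
    import torch.nn.functional as F

    out = torch.randn(1, 8, 5, 5)  # deconv output before the bias is applied
    bias = torch.randn(8)          # per-channel bias
    out_padding = (1, 1)           # post-padding per spatial dim (H, W)

    # F.pad takes (left, right, top, bottom) for the last two dims; only the
    # trailing edge of each spatial dim is padded, mirroring the fill-mode slice.
    padded = F.pad(out, (0, out_padding[1], 0, out_padding[0]))
    result = padded + bias.view(1, -1, 1, 1)  # broadcast add over C
    print(result.shape)  # torch.Size([1, 8, 6, 6])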
@@ -190,51 +258,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) #endif if (hasOutputPadding) { LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding); - - // Add padding layer - nvinfer1::ITensor* start; - nvinfer1::ITensor* totalPadding; - auto in_nbDims = orig_dims.nbDims; - std::vector<int64_t> startVec(in_nbDims, 0); - std::vector<int64_t> totalPaddingVec(in_nbDims, 0); - int32_t diff = in_nbDims - out_padding.nbDims; - for (int32_t i = diff; i < in_nbDims; i++) { - int32_t idx = i - diff; - startVec[i] = 0; // Don't need begin padding, only post padding - totalPaddingVec[i] = out_padding.d[idx]; - } - start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32)); - totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32)); - nvinfer1::ITensor* tensorPtr = deconv->getOutput(0); - nvinfer1::ITensor* deconvOutShape = ctx->net->addShape(*tensorPtr)->getOutput(0); - const auto size = - ctx->net->addElementWise(*deconvOutShape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - - nvinfer1::Dims stride; - stride.nbDims = in_nbDims; - for (int64_t i = 0; i < in_nbDims; i++) { - stride.d[i] = 1; - } - const auto& dummy = stride; - auto* sliceLayer = ctx->net->addSlice(*tensorPtr, dummy, dummy, stride); - sliceLayer->setInput(1, *start); - sliceLayer->setInput(2, *size); - sliceLayer->setMode(nvinfer1::SliceMode::kFILL); - tensorPtr = sliceLayer->getOutput(0); - - nvinfer1::Dims constantDims; - constantDims.nbDims = in_nbDims; - for (int64_t i = 0; i < in_nbDims; i++) { - constantDims.d[i] = 1; - } - constantDims.d[diff - 1] = - bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast - auto const_layer = ctx->net->addConstant(constantDims, bias.data); - auto add_bias_layer = - ctx->net->addElementWise(*tensorPtr, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM); - - new_layer = add_bias_layer; + new_layer = add_bias_layer(ctx, tensorPtr, orig_dims, out_padding, bias); } else { new_layer = deconv; } diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp index c4b12da810..90772ea8c4 100644 --- a/core/conversion/converters/impl/matrix_multiply.cpp +++ b/core/conversion/converters/impl/matrix_multiply.cpp @@ -16,12 +16,28 @@ auto mm_registrations TORCHTRT_UNUSED = [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto self = args[0].ITensorOrFreeze(ctx); auto other = args[1].ITensorOrFreeze(ctx); + + auto selfDims = self->getDimensions().nbDims; + auto otherDims = other->getDimensions().nbDims; + + bool squeezeFront = false; + bool squeezeBack = false; + + if (selfDims == 1 && selfDims < otherDims) { + squeezeFront = true; + } else if (otherDims == 1 && otherDims < selfDims) { + // Append a 1 to the end of the shape before padding front to match self + other = addPadding(ctx, n, other, 2, true, false); + otherDims = other->getDimensions().nbDims; + squeezeBack = true; + } + // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if // necessary.
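The `squeezeFront`/`squeezeBack` handling introduced in this hunk mirrors `torch.matmul`'s 1-D operand rules: a 1-D first operand is treated as [1, K] and the prepended dimension is squeezed from the result, while a 1-D second operand is treated as [K, 1] and the appended dimension is squeezed. A short illustration of the shapes involved (illustrative only, not converter code):

    import torch

    v4 = torch.randn(4)
    v5 = torch.randn(5)
    b = torch.randn(3, 4, 5)  # batched matrix

    # 1-D first operand: padded to [1, 4], leading 1 squeezed from the result
    print(torch.matmul(v4, b).shape)  # torch.Size([3, 5])

    # 1-D second operand: padded to [5, 1], trailing 1 squeezed from the result
    print(torch.matmul(b, v5).shape)  # torch.Size([3, 4])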
- if (self->getDimensions().nbDims < other->getDimensions().nbDims) { - self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false); - } else { - other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false); + if (selfDims < otherDims) { + self = addPadding(ctx, n, self, otherDims, false, false); + } else if (otherDims < selfDims) { + other = addPadding(ctx, n, other, selfDims, false, false); } auto mm_layer = ctx->net->addMatrixMultiply( @@ -29,7 +45,20 @@ auto mm_registrations TORCHTRT_UNUSED = TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n); mm_layer->setName(util::node_info(n).c_str()); - auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0)); + auto out = mm_layer->getOutput(0); + + if (squeezeFront || squeezeBack) { + auto squeezeDimOffset = squeezeFront ? 2 : 1; + auto reshapeDims = + util::squeezeDims(out->getDimensions(), out->getDimensions().nbDims - squeezeDimOffset); + auto shuffle_layer = ctx->net->addShuffle(*out); + LOG_DEBUG("Squeezing matmul output for 1d correction: " << reshapeDims); + TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); + shuffle_layer->setReshapeDimensions(reshapeDims); + shuffle_layer->setName((util::node_info(n) + "_squeeze").c_str()); + out = shuffle_layer->getOutput(0); + } + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out); LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 1acc27dda5..5ad0efb3b0 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -87,7 +87,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def_pickle( [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<char> { // Serialize TensorRT engine - auto serialized_trt_engine = self->cuda_engine->serialize(); + auto serialized_trt_engine = make_trt(self->cuda_engine->serialize()); // Adding device info related meta data to the serialized file auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size()); diff --git a/cpp/include/torch_tensorrt/macros.h b/cpp/include/torch_tensorrt/macros.h index d86733f291..1eb8c6feb5 100644 --- a/cpp/include/torch_tensorrt/macros.h +++ b/cpp/include/torch_tensorrt/macros.h @@ -24,7 +24,7 @@ #define STR(x) XSTR(x) #define TORCH_TENSORRT_MAJOR_VERSION 2 -#define TORCH_TENSORRT_MINOR_VERSION 1 +#define TORCH_TENSORRT_MINOR_VERSION 2 #define TORCH_TENSORRT_PATCH_VERSION 0 #define TORCH_TENSORRT_VERSION \ STR(TORCH_TENSORRT_MAJOR_VERSION) \ diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index cdf9b92de8..4cfbfe595a 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,3 +1,4 @@ +__version__: "2.2.0" __cuda_version__: "12.1" __cudnn_version__: "8.9" __tensorrt_version__: "8.6" diff --git a/docker/dist-build.sh b/docker/dist-build.sh index fab62f21cc..ba7dd83713 100755 --- a/docker/dist-build.sh +++ b/docker/dist-build.sh @@ -3,9 +3,9 @@ TOP_DIR=$(cd $(dirname $0); pwd)/.. if [[ -z "${USE_CXX11}" ]]; then - BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/cu121 -w dist" + BUILD_CMD="python -m pip wheel . --extra-index-url https://download.pytorch.org/whl/test/cu121 -w dist" else - BUILD_CMD="python -m pip wheel .
--config-setting="--build-option=--use-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/cu121 -w dist" + BUILD_CMD="python -m pip wheel . --config-setting="--build-option=--use-cxx11-abi" --extra-index-url https://download.pytorch.org/whl/test/cu121 -w dist" fi # TensorRT restricts our pip version diff --git a/docsrc/index.rst b/docsrc/index.rst index 75bd6983d7..455aeab8b3 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -87,6 +87,7 @@ User Guide :maxdepth: 1 :hidden: + user_guide/dynamic_shapes user_guide/ptq user_guide/saving_models @@ -206,4 +207,4 @@ Legacy Further Information (TorchScript) * `GTC 2021 Fall Talk `_ * `PyTorch Ecosystem Day 2021 `_ * `PyTorch Developer Conference 2021 `_ -* `PyTorch Developer Conference 2022 `_ \ No newline at end of file +* `PyTorch Developer Conference 2022 `_ diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst index 9f3ca3eb27..6d890d0450 100644 --- a/docsrc/user_guide/saving_models.rst +++ b/docsrc/user_guide/saving_models.rst @@ -8,10 +8,10 @@ Saving models compiled with Torch-TensorRT :members: :undoc-members: :show-inheritance: - + Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation. -Dynamo IR +Dynamo IR ------------- Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based. @@ -41,7 +41,7 @@ The following code illustrates this approach. b) ExportedProgram ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant +`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant `torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk. .. code-block:: python @@ -51,18 +51,17 @@ b) ExportedProgram model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] - exp_program = torch_tensorrt.dynamo.trace(model, inputs) - trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule # Transform and create an exported program - trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs, exp_program.call_spec, ir="exported_program") + trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs) torch.export.save(trt_exp_program, "trt_model.ep") - # Later, you can load it and run inference + # Later, you can load it and run inference model = torch.export.load("trt_model.ep") model(*inputs) -`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes, stiches all the nodes together and creates an ExportedProgram. -This is needed as `torch.export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes). +`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together. +This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes). .. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
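For completeness, the updated `torch_tensorrt.dynamo.export` (see its new signature in `_exporter.py` later in this patch) also retains a `torchscript` path, its default `ir`, which simply runs `torch.jit.trace` on the compiled module. A hedged sketch of that path, reusing the `trt_gm` and `inputs` defined in the example above:

    # Sketch only; assumes the `trt_gm` and `inputs` from the example above.
    trt_ts = torch_tensorrt.dynamo.export(trt_gm, inputs, ir="torchscript")
    torch.jit.save(trt_ts, "trt_model.ts")

    # Later, load it and run inference
    model = torch.jit.load("trt_model.ts").cuda()
    model(*inputs)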
diff --git a/examples/int8/training/vgg16/requirements.txt b/examples/int8/training/vgg16/requirements.txt index a56f0e40fe..d02af2c616 100644 --- a/examples/int8/training/vgg16/requirements.txt +++ b/examples/int8/training/vgg16/requirements.txt @@ -1,6 +1,6 @@ tensorboard>=1.14.0 protobuf==3.20.* nvidia-pyindex ---extra-index-url https://pypi.ngc.nvidia.com -pytorch-quantization>=2.1.2 +--extra-index-url https://pypi.nvidia.com +pytorch-quantization tqdm diff --git a/noxfile.py b/noxfile.py index 7a4da40ea3..40e3e8f28d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,7 +1,8 @@ -from distutils.command.clean import clean -import nox import os import sys +from distutils.command.clean import clean + +import nox # Use system installed Python packages PYT_PATH = ( @@ -203,11 +204,11 @@ def run_base_tests(session): session.run_always("pytest", test) -def run_fx_core_tests(session): - print("Running FX core tests") - session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) +def run_dynamo_backend_tests(session): + print("Running Dynamo backend tests") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/")) tests = [ - "core", + "backend", ] for test in tests: if USE_HOST_DEPS: @@ -216,33 +217,23 @@ def run_fx_core_tests(session): session.run_always("pytest", test) -def run_fx_converter_tests(session): - print("Running FX converter tests") - session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) +def run_dynamo_converter_tests(session): + print("Running Dynamo converter tests") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/")) tests = [ - "converters", + "conversion", ] - # Skipping this test as it fails inside NGC container with the following error. - # Error Code 4: Internal Error (Could not find any implementation for node conv due to insufficient workspace. See verbose log for requested sizes.) - skip_tests = "-k not conv3d" for test in tests: if USE_HOST_DEPS: - session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH}) + session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) else: - session.run_always("pytest", test, skip_tests) + session.run_always("pytest", test) -def run_fx_lower_tests(session): - print("Running FX passes and trt_lower tests") - session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) - tests = [ - "passes/test_multi_fuse_trt.py", - # "passes/test_fuse_permute_linear_trt.py", - "passes/test_remove_duplicate_output_args.py", - "passes/test_fuse_permute_matmul_trt.py", - # "passes/test_graph_opts.py" - "trt_lower", - ] +def run_dynamo_lower_tests(session): + print("Running Dynamo lowering passes") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/")) + tests = ["lowering"] for test in tests: if USE_HOST_DEPS: session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) @@ -250,31 +241,22 @@ def run_fx_lower_tests(session): session.run_always("pytest", test) -def run_fx_quant_tests(session): - print("Running FX Quant tests") - session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) - tests = [ - "quant", - ] - # Skipping this test as it fails inside NGC container with the following error.
- # ImportError: cannot import name 'ObservationType' from 'torch.ao.quantization.backend_config.observation_type' - skip_tests = "-k not conv_add_standalone_module" +def run_dynamo_partitioning_tests(session): + print("Running Dynamo Partitioning tests") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/")) + tests = ["partitioning"] for test in tests: if USE_HOST_DEPS: - session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH}) + session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) else: - session.run_always("pytest", test, skip_tests) + session.run_always("pytest", test) -def run_fx_tracer_tests(session): - print("Running FX Tracer tests") - session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) - # skipping a test since it depends on torchdynamo - # Enable this test once NGC moves to latest pytorch which has dynamo integrated. +def run_dynamo_runtime_tests(session): + print("Running Dynamo Runtime tests") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/")) tests = [ - "tracer/test_acc_shape_prop.py", - "tracer/test_acc_tracer.py", - # "tracer/test_dispatch_tracer.py" + "runtime", ] for test in tests: if USE_HOST_DEPS: @@ -283,30 +265,36 @@ def run_fx_tracer_tests(session): session.run_always("pytest", test) -def run_fx_tools_tests(session): - print("Running FX tools tests") - session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) +def run_dynamo_model_compile_tests(session): + print("Running model torch-compile tests") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/models")) tests = [ - "tools", + "test_models.py", ] for test in tests: if USE_HOST_DEPS: - session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) + session.run_always( + "python", + test, + "--ir", + str("torch_compile"), + env={"PYTHONPATH": PYT_PATH}, + ) else: - session.run_always("pytest", test) + session.run_always("python", test, "--ir", str("torch_compile")) -def run_model_tests(session): - print("Running model tests") - session.chdir(os.path.join(TOP_DIR, "tests/py/ts")) - tests = [ - "models", - ] +def run_dynamo_model_export_tests(session): + print("Running model torch-export tests") + session.chdir(os.path.join(TOP_DIR, "tests/py/dynamo/models")) + tests = ["test_models_export.py", "test_export_serde.py"] for test in tests: if USE_HOST_DEPS: - session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) + session.run_always( + "python", test, "--ir", str("dynamo"), env={"PYTHONPATH": PYT_PATH} + ) else: - session.run_always("pytest", test) + session.run_always("python", test, "--ir", str("dynamo")) def run_accuracy_tests(session): @@ -403,37 +391,61 @@ def run_l0_api_tests(session): cleanup(session) -def run_l0_fx_tests(session): +def run_l0_dynamo_tests(session): + if not USE_HOST_DEPS: + install_deps(session) + install_torch_trt(session) + run_dynamo_backend_tests(session) + run_dynamo_converter_tests(session) + run_dynamo_lower_tests(session) + cleanup(session) + + +def run_l0_dynamo_backend_tests(session): + if not USE_HOST_DEPS: + install_deps(session) + install_torch_trt(session) + run_dynamo_backend_tests(session) + cleanup(session) + + +def run_l0_dynamo_converter_tests(session): + if not USE_HOST_DEPS: + install_deps(session) + install_torch_trt(session) + run_dynamo_converter_tests(session) + cleanup(session) + + +def run_l0_dynamo_lower_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - run_fx_core_tests(session) - run_fx_converter_tests(session) - run_fx_lower_tests(session) + 
run_dynamo_lower_tests(session) cleanup(session) -def run_l0_fx_core_tests(session): +def run_l0_dynamo_model_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - run_fx_core_tests(session) + run_dynamo_model_tests(session) cleanup(session) -def run_l0_fx_converter_tests(session): +def run_l0_dynamo_partitioning_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - run_fx_converter_tests(session) + run_dynamo_partitioning_tests(session) cleanup(session) -def run_l0_fx_lower_tests(session): +def run_l0_dynamo_runtime_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - run_fx_lower_tests(session) + run_dynamo_runtime_tests(session) cleanup(session) @@ -446,12 +458,13 @@ def run_l0_dla_tests(session): cleanup(session) -def run_l1_model_tests(session): +def run_dynamo_model_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) download_models(session) - run_model_tests(session) + run_dynamo_model_compile_tests(session) + run_dynamo_model_export_tests(session) cleanup(session) @@ -465,13 +478,13 @@ def run_l1_int8_accuracy_tests(session): cleanup(session) -def run_l1_fx_tests(session): +def run_l1_dynamo_tests(session): if not USE_HOST_DEPS: install_deps(session) install_torch_trt(session) - run_fx_quant_tests(session) - run_fx_tracer_tests(session) - run_fx_tools_tests(session) + run_dynamo_model_tests(session) + run_dynamo_partitioning_tests(session) + run_dynamo_runtime_tests(session) cleanup(session) @@ -499,27 +512,27 @@ def l0_api_tests(session): @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def l0_fx_tests(session): +def l0_dynamo_tests(session): """When a developer needs to check correctness for a PR or something""" - run_l0_fx_tests(session) + run_l0_dynamo_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def l0_fx_core_tests(session): +def l0_dynamo_backend_tests(session): """When a developer needs to check correctness for a PR or something""" - run_l0_fx_core_tests(session) + run_l0_dynamo_backend_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def l0_fx_converter_tests(session): +def l0_dynamo_converter_tests(session): """When a developer needs to check correctness for a PR or something""" - run_l0_fx_converter_tests(session) + run_l0_dynamo_converter_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def l0_fx_lower_tests(session): +def l0_dynamo_lower_tests(session): """When a developer needs to check correctness for a PR or something""" - run_l0_fx_lower_tests(session) + run_l0_dynamo_lower_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) @@ -531,13 +544,13 @@ def l0_dla_tests(session): @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l1_model_tests(session): """When a user needs to test the functionality of standard models compilation and results""" - run_l1_model_tests(session) + run_dynamo_model_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) -def l1_fx_tests(session): +def l1_dynamo_tests(session): """When a user needs to test the functionality of standard models compilation and results""" - run_l1_fx_tests(session) + run_l1_dynamo_tests(session) @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) diff --git a/py/requirements.txt b/py/requirements.txt index 56257d2414..b993362038 100644 --- a/py/requirements.txt +++ b/py/requirements.txt 
@@ -1,7 +1,8 @@ numpy packaging pybind11==2.6.2 -torch==2.1.2 +--extra-index-url https://download.pytorch.org/whl/test/cu121 +torch==2.2.0 torchvision==0.16.2 --extra-index-url https://pypi.nvidia.com tensorrt==8.6.1 diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 6e43a23903..9acb073c62 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -47,6 +47,7 @@ class _ShapeMode(Enum): high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET torch_dtype: torch.dtype = torch.float32 torch_tensor: torch.Tensor = None + name: str = "" def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering "None" (or not specifying) will set the bound to [0, 2) - + torch_tensor (torch.Tensor): Holds a torch tensor corresponding to this Input. + name (str, optional): Name of this input in the input nn.Module's forward function. Used to specify dynamic shapes for the corresponding input in the dynamo tracer. Examples: - Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last) - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) @@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: else: self.torch_tensor = self.example_tensor() + if "name" in kwargs: + self.name = kwargs["name"] + def __str__(self) -> str: if self.shape_mode == Input._ShapeMode.STATIC: return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d91ddab15f..720941d619 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -2,11 +2,12 @@ import collections.abc import logging -from typing import Any, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch import torch_tensorrt from torch.export import ExportedProgram +from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum EngineCapability, ) @@ -42,7 +43,10 @@ convert_module, repair_long_or_double_inputs, ) -from torch_tensorrt.dynamo.lowering import apply_lowering_passes +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS as CONVERTERS, +) +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, prepare_inputs, @@ -55,8 +59,8 @@ def compile( - exported_program: Union[torch.fx.GraphModule, ExportedProgram], - inputs: Any, + exported_program: ExportedProgram, + inputs: Tuple[Any, ...], *, device: Optional[Union[Device, torch.device, str]] = DEVICE, disable_tf32: bool = DISABLE_TF32, @@ -75,7 +79,7 @@ def compile( truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, min_block_size: int = MIN_BLOCK_SIZE, - torch_executed_ops: Optional[List[str]] = None, + torch_executed_ops: Optional[Collection[Target]] = None,
torch_executed_modules: Optional[List[str]] = None, pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, max_aux_streams: Optional[int] = MAX_AUX_STREAMS, @@ -131,7 +135,7 @@ def compile( calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch min_block_size (int): The minimum number of contiguous TensorRT convertible operations in order to run a set of operations in TensorRT - torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True + torch_executed_ops (Collection[Target]): Collection of aten operators that must be run in PyTorch. An error will be thrown if this collection is not empty but ``require_full_compilation`` is True torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows) max_aux_streams (Optional[int]): Maximum streams in the engine @@ -155,15 +159,14 @@ def compile( inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) - if isinstance(exported_program, torch.fx.GraphModule): - gm = exported_program - elif isinstance(exported_program, ExportedProgram): - gm = exported_program.module() - else: + if not isinstance(exported_program, ExportedProgram): raise AssertionError( - f"Input graph should either be an ExportedProgram or a GraphModule but got type {type(exported_program)}" + f"Input graph should be an ExportedProgram but got type {type(exported_program)}" ) - + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = exported_program.module() logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module @@ -199,7 +202,7 @@ def compile( "min_block_size": min_block_size, "torch_executed_ops": torch_executed_ops if torch_executed_ops is not None - else [], + else set(), "pass_through_build_failures": pass_through_build_failures, "max_aux_streams": max_aux_streams, "version_compatible": version_compatible, @@ -240,6 +243,9 @@ def compile_module( Compiled FX GraphModule """ + # Set torch-executed ops + CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( gm, settings.debug, settings.torch_executed_ops diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index b496f94a74..df9150ea2d 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -3,10 +3,16 @@ from typing import Any, Dict, Sequence, Tuple, cast import torch -from torch._export.exported_program import CallSpec from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensor from torch.export import ExportedProgram, ExportGraphSignature +from torch.export.exported_program import ( + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TensorArgument, +) from torch_tensorrt.dynamo import partitioning @@ -14,7 +20,6 @@
def export( gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor], *, - call_spec: CallSpec = None, ir: str = "torchscript", ) -> ExportedProgram: """Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded. @@ -40,15 +45,13 @@ def export( format=torch.channels_last ), # Dynamic input shape for input #2 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings - call_spec (CallSpec): CallSpec object of exported program. This is None for ir=torchscript. For ir=exported_program, this should be set to input ExportedProgram's call_spec object. ir (str): torchscript | exported_program. Based on the provided ir, the output type would be a torchscript or exported program. """ if ir == "torchscript": return torch.jit.trace(gm, inputs) elif ir == "exported_program": - assert call_spec patched_module = transform(gm, inputs) - exp_program = create_trt_exp_program(patched_module, call_spec) + exp_program = create_trt_exp_program(patched_module) return exp_program else: @@ -162,6 +165,7 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # Copy all nodes in the submodule into gm and # store the output node of this submodule which is now present in gm + submodule_output = gm.graph.graph_copy(submodule.graph, val_map) # Get their references (since we copied) in the parent graph (gm) @@ -218,39 +222,31 @@ def copy_submodule_attributes(gm: torch.fx.GraphModule, submod_name: str) -> Non def create_trt_exp_program( - gm: torch.fx.GraphModule, call_spec: CallSpec + gm: torch.fx.GraphModule, ) -> ExportedProgram: """Creates a new Exported Program. This function takes a torch.fx.GraphModule which has TRT engines and constructs an Exported Program object with the new IO node names and state_dict """ - input_node_names = [ - node.name for node in gm.graph.nodes if node.op == "placeholder" + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + output_nodes = [node for node in gm.graph.nodes if node.op == "output"] + assert output_nodes + output_nodes = output_nodes[0].args[0] + + input_specs = [ + InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target) + for node in input_nodes + ] + output_specs = [ + OutputSpec(OutputKind.USER_OUTPUT, TensorArgument(name=node.name), node.target) + for node in output_nodes ] - output_node_names = [node.name for node in gm.graph.nodes if node.op == "output"] - param_names = [param[0] for param in gm.named_parameters()] - buffer_names = [buffer[0] for buffer in gm.named_buffers()] - inputs_to_parameters = {} - inputs_to_buffers = {} - for node in gm.graph.nodes: - if node.target in param_names: - inputs_to_parameters[node.name] = node.target - if node.target in buffer_names: - inputs_to_buffers[node.name] = node.target trt_graph_signature = ExportGraphSignature( - parameters=param_names, - buffers=buffer_names, - user_inputs=input_node_names, - user_outputs=output_node_names, - inputs_to_parameters=inputs_to_parameters, - inputs_to_buffers=inputs_to_buffers, - buffers_to_mutate={}, - backward_signature=None, - assertion_dep_token=None, + input_specs=input_specs, output_specs=output_specs ) trt_exp_program = ExportedProgram( - gm, gm.graph, trt_graph_signature, call_spec, gm.state_dict(), {}, [], [] + gm, gm.graph, trt_graph_signature, gm.state_dict(), {}, [], [], [] ) return trt_exp_program @@ -282,6 +278,7 @@ def inline_trt_modules( (trt_module_node.args, trt_module.engine), ) trt_node.meta["val"] = [] + assert num_outputs > 0 #
Generate meta data for TRT node (a FakeTensor with corresponding output shape) for idx in range(num_outputs): trt_node.meta["val"].append( @@ -298,12 +295,16 @@ def inline_trt_modules( # Insert getitem nodes as outputs (for export serialization to work) with gm.graph.inserting_after(trt_node): getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) + getitem_output.meta["val"] = trt_node.meta["val"] trt_module_node.replace_all_uses_with(getitem_output) else: # Multiple outputs case: # Replace uses of submodule with the trt_node. # getitem nodes are already added inherently by the partitioner trt_module_node.replace_all_uses_with(trt_node) + getitem_nodes = trt_node.users + for idx, getitem_node in enumerate(getitem_nodes): + getitem_node.meta["val"] = trt_node.meta["val"][idx] # Erase the TRT submodule (call_module) node. gm.graph.erase_node(trt_module_node) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index cd58c9547f..2d5ffa2ede 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, field -from typing import Optional, Set +from typing import Collection, Optional import torch from tensorrt import EngineCapability +from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt.dynamo._defaults import ( DEBUG, @@ -39,7 +40,7 @@ class CompilationSettings: debug (bool): Whether to print out verbose debugging information workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) min_block_size (int): Minimum number of operators per TRT-Engine Block - torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage + torch_executed_ops (Collection[Target]): Collection of operations to run in Torch, regardless of converter coverage pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False) max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine version_compatible (bool): Provide version forward-compatibility for engine plan files @@ -69,7 +70,7 @@ class CompilationSettings: debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE - torch_executed_ops: Set[str] = field(default_factory=set) + torch_executed_ops: Collection[Target] = field(default_factory=set) pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES max_aux_streams: Optional[int] = MAX_AUX_STREAMS version_compatible: bool = VERSION_COMPATIBLE diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 5fdca08399..339b55b6b3 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -1,45 +1,20 @@ from __future__ import annotations import logging -import unittest.mock -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Tuple import torch -from torch._export import dynamic_dim, export -from torch_tensorrt._Device import Device +from torch.export import Dim, export from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import ( - DEBUG, - DEVICE, - ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - default_device, -) -from torch_tensorrt.dynamo.lowering import get_decompositions +from torch_tensorrt.dynamo._defaults import DEBUG, default_device from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device logger = 
logging.getLogger(__name__) -def get_random_tensor( - shape: List[Any], dtype: torch.dtype, device: torch.device -) -> torch.Tensor: - if dtype == torch.int32 or dtype == torch.int64: - return torch.randint(2, 10, shape, dtype=dtype, device=device) - elif dtype in (torch.float64, torch.float32, torch.float16): - return torch.randn(shape, dtype=dtype, device=device) - else: - logger.critical( - "Invalid dtype detected in creating input tensors for tracing the graph." - ) - raise - - def trace( mod: torch.nn.Module | torch.fx.GraphModule, inputs: Tuple[Any, ...], - device: Optional[Union[Device, torch.device, str]] = DEVICE, - debug: bool = DEBUG, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -65,9 +40,9 @@ def trace( torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] Keyword Arguments: - device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: + device (Union(torch.device, dict)): Target device for TensorRT engines to run on :: - device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) + device=torch.device("cuda:0") debug (bool): Enable debuggable engine enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT. @@ -77,50 +52,36 @@ def trace( """ # Set log level at the top of compilation (torch_tensorrt.dynamo) + debug = kwargs.get("debug", DEBUG) if debug: set_log_level(logger.parent, logging.DEBUG) - device = to_torch_device(device if device else default_device()) - # Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT - # Torch dynamo does not allow 0/1 value for dynamic dimensions - # for inputs during tracing.
Hence we create new inputs for export + device = to_torch_device(kwargs.get("device", default_device())) torch_inputs = get_torch_inputs(inputs, device) - trace_inputs = [] - constraints = [] - for idx, input in enumerate(inputs): - if input.shape_mode == Input._ShapeMode.DYNAMIC: - min_shape = input.shape["min_shape"] - opt_shape = input.shape["opt_shape"] - max_shape = input.shape["max_shape"] + dynamic_shapes = {} + for input in inputs: + if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC: + if not input.name: + raise AssertionError( + f"Expected a name for a dynamic input with shape {input.shape} but found none" + ) + min_shape = input.shape["min_shape"] # type: ignore + opt_shape = input.shape["opt_shape"] # type: ignore + max_shape = input.shape["max_shape"] # type: ignore assert len(min_shape) == len(opt_shape) == len(max_shape) - - constraint_dims = [] - new_shape = [] + dynamic_dims = {} for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: - new_shape.append(torch_inputs[idx].shape[dim]) + continue else: - constraint_dims.append(dim) - if torch_inputs[idx].shape[dim] == 1: - new_shape.append(torch_inputs[idx].shape[dim] + 1) - else: - new_shape.append(torch_inputs[idx].shape[dim]) - - trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device) + dynamic_dims[dim] = Dim( + input.name + "_" + str(dim), + min=min_shape[dim], + max=max_shape[dim], + ) - for dim in constraint_dims: - if min_shape[dim] > 1: - constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim)) - if max_shape[dim] > 1: - constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim]) - trace_inputs.append(trace_input) - else: - trace_inputs.append(torch_inputs[idx]) + dynamic_shapes[input.name] = dynamic_dims - with unittest.mock.patch( - "torch._export.DECOMP_TABLE", - get_decompositions(enable_experimental_decompositions), - ): - exp_program = export(mod, tuple(trace_inputs), constraints=constraints) + exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes) return exp_program
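The rewritten `trace` above replaces the removed `dynamic_dim` constraint machinery with `torch.export.Dim` objects keyed by the new `Input.name` field. A hedged usage sketch (`MyModel` and the shape values are illustrative, not from the patch):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # illustrative module
    inputs = [
        torch_tensorrt.Input(
            name="x",  # must match the argument name in the module's forward()
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.float32,
        )
    ]
    # Dimension 0 differs across min/opt/max, so trace() creates
    # Dim("x_0", min=1, max=16) and passes it to torch.export via dynamic_shapes.
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)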
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index d689de3e54..050a62ef3e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -6,6 +6,7 @@ from typing import ( Any, Callable, + Collection, Dict, List, Optional, @@ -212,8 +213,16 @@ def __init__( CallingConvention.CTX for _ in range(len(self.registries)) ] + self.disallowed_targets: Collection[Target] = set() + self.validate_invariants() + def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None: + self.disallowed_targets = torch_executed_ops + + def get_disallowed_targets(self) -> Collection[Target]: + return self.disallowed_targets + def validate_invariants(self) -> None: """Validates the invariants required of the dictionaries in the registries @@ -253,6 +262,14 @@ def __getitem_without_validation__( self.validate_invariants() + if ( + key in self.disallowed_targets + or self.qualified_name_or_str(key) in self.disallowed_targets + ): + raise KeyError( + f"A converter exists for {key}, but it was " "explicitly disallowed" + ) + # Iterate over all registries and return the first converter found for registry, calling_convention in zip( self.registries, self.registry_calling_conventions @@ -288,6 +305,14 @@ def __getitem__( self.validate_invariants() key = node.target + if ( + key in self.disallowed_targets + or self.qualified_name_or_str(key) in self.disallowed_targets + ): + raise KeyError( + f"A converter exists for {key}, but it was " "explicitly disallowed" + ) + # Iterate over all registries, validating the converter on the input node # If no capability_validator function is found, assume full coverage for registry, calling_convention in zip( diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 33f0de0ead..0592160a4c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -330,6 +330,34 @@ def aten_ops_fmod( return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1]) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (TRTTensor,), + } +) +def aten_ops_grid( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.grid.grid( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + grid=args[1], + interpolation_mode=args[2], + padding_mode=args[3], + align_corners=args[4], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.relu.default) def aten_ops_relu( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.activation.relu( ctx, target, SourceIR.ATEN, name, args[0], ) @@ -364,7 +392,19 @@ def aten_ops_sigmoid( ) -@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor) +def index_dtype_validator(node: Node) -> bool: + index = node.args[1] + for ind in index: + if ind is not None: + val = ind.meta.get("val") + if val is not None and val.dtype != torch.int32: + return False + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator +) @enforce_tensor_types( { 0: (TRTTensor,), @@ -459,25 +499,6 @@ def aten_ops_softplus( ) -@dynamo_tensorrt_converter(torch.ops.aten.clip.default) -def aten_ops_clip( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.activation.clip( - ctx, - target, - SourceIR.ATEN, - name, - args[0], - alpha=args_bounds_check(args, 1), - beta=args_bounds_check(args, 2), - ) - - @dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default) def aten_ops_hard_sigmoid( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: @@ -655,6 +676,9 @@ def aten_ops_where( @dynamo_tensorrt_converter(torch.ops.aten.clamp.default) +@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.clip.default) +@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor) def aten_ops_clamp( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: @@ -759,6 +783,29 @@ def aten_ops_cumsum( ) +@dynamo_tensorrt_converter(torch.ops.aten.tile.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_tile( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.tile( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.permute.default) @enforce_tensor_types( { @@ -848,7 +895,7 @@ def aten_ops_clone_copy_dtype( name, args[0], kwargs.get("dtype", args[0].dtype), - force_layer=False, + force_layer=True, ) @@ -904,18 +951,11 @@ def
aten_ops_expand( ) -def amax_param_validator(amax_node: Node) -> bool: - if len(amax_node.args) < 2: - _LOGGER.debug( - f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args." - ) - return False - - return True - - -@dynamo_tensorrt_converter( - torch.ops.aten.amax.default, capability_validator=amax_param_validator +@dynamo_tensorrt_converter(torch.ops.aten.amax.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } ) def aten_ops_amax( ctx: ConversionContext, @@ -935,6 +975,30 @@ def aten_ops_amax( ) +@dynamo_tensorrt_converter(torch.ops.aten.amin.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_amin( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.reduce.amin( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args_bounds_check(args, 1, replacement=[]), + args_bounds_check(args, 2, replacement=False), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.sum.default) @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList) @dynamo_tensorrt_converter(torch.ops.prims.sum.default) @@ -963,7 +1027,7 @@ def aten_ops_sum( name, sum_, kwargs["output_dtype"], - force_layer=False, + force_layer=True, ) else: return sum_ @@ -1688,9 +1752,177 @@ def aten_ops_logical_xor( ) +def bitwise_type_validator(node: Node) -> bool: + supported_type = [torch.bool, bool] + + tensor_targets = [ + torch.ops.aten.bitwise_and.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.bitwise_xor.Tensor, + ] + scalar_targets = [ + torch.ops.aten.bitwise_and.Scalar, + torch.ops.aten.bitwise_or.Scalar, + torch.ops.aten.bitwise_xor.Scalar, + ] + scalar_tensor_targets = [ + torch.ops.aten.bitwise_and.Scalar_Tensor, + torch.ops.aten.bitwise_or.Scalar_Tensor, + torch.ops.aten.bitwise_xor.Scalar_Tensor, + ] + + if node.target in tensor_targets: + lhs_val = node.args[0] + rhs_val = node.args[1] + lhs_meta = lhs_val.meta.get("tensor_meta") + rhs_meta = rhs_val.meta.get("tensor_meta") + if lhs_meta is None or rhs_meta is None: + return False + return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type + + elif node.target in scalar_targets: + lhs_val = node.args[0] + rhs_val = node.args[1] + lhs_meta = lhs_val.meta.get("tensor_meta") + if lhs_meta is None: + return False + return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool) + + elif node.target in scalar_tensor_targets: + lhs_val = node.args[0] + rhs_val = node.args[1] + rhs_meta = rhs_val.meta.get("tensor_meta") + if rhs_meta is None: + return False + return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type + + else: + return False + + +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_and.Scalar_Tensor, + capability_validator=bitwise_type_validator, +) +def aten_ops_bitwise_and( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.bitwise_and( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + 
torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator +) +def aten_ops_bitwise_or( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.bitwise_or( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator +) +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_xor.Scalar_Tensor, + capability_validator=bitwise_type_validator, +) +def aten_ops_bitwise_xor( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.bitwise_xor( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +def bitwise_not_type_validator(node: Node) -> bool: + val = node.args[0] + val_meta = val.meta.get("tensor_meta") + + if val_meta is None: + return False + + supported_type = [torch.bool, bool] + return val_meta.dtype in supported_type + + +@dynamo_tensorrt_converter( + torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator +) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_bitwise_not( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.bitwise_not( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) -def aten_ops_equal( +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_eq( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -1707,9 +1939,38 @@ def aten_ops_equal( ) +@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_ne( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.ne( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) @dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) -def aten_ops_greater( +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_gt( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -1726,9 +1987,38 @@ def aten_ops_greater( ) +@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_ge( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.ge( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) @dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) -def aten_ops_less( +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def 
aten_ops_lt( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -1745,6 +2035,30 @@ def aten_ops_less( ) +@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor) +@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_le( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.le( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + def conv_param_validator(conv_node: Node) -> bool: return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) @@ -1776,7 +2090,7 @@ def aten_ops_convolution( is_conv1d=len(args[3]) == 1, input=args[0], weight=args[1], - bias=args[2], + bias=args_bounds_check(args, 2, None), stride=args[3], padding=args[4], dilation=args[5], @@ -1791,7 +2105,7 @@ def aten_ops_convolution( is_deconv1d=len(args[3]) == 1, input=args[0], weight=args[1], - bias=args[2], + bias=args_bounds_check(args, 2, None), stride=args[3], padding=args[4], dilation=args[5], @@ -1966,7 +2280,27 @@ def aten_ops_argmax( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.argmax.argmax( + return impl.topk.argmax( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + dim=args_bounds_check(args, 1), + keep_dim=args_bounds_check(args, 2, False), + ) + + +@enforce_tensor_types({0: (TRTTensor,)}) +@dynamo_tensorrt_converter(torch.ops.aten.argmin.default) +def aten_ops_argmin( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.topk.argmin( ctx, target, SourceIR.ATEN, @@ -1975,3 +2309,266 @@ def aten_ops_argmax( dim=args_bounds_check(args, 1), keep_dim=args_bounds_check(args, 2, False), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (np.ndarray, torch.Tensor, TRTTensor), + 2: (np.ndarray, torch.Tensor, TRTTensor), + } +) +def aten_ops_addmm( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.addmm.addmm( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + beta=kwargs.get("beta", 1), + alpha=kwargs.get("alpha", 1), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.constant_pad_nd.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_constant_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.constant_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args_bounds_check(args, 2, 0), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad1d.default) +@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad2d.default) +@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad3d.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_reflection_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.reflection_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.replication_pad1d.default) 
+@dynamo_tensorrt_converter(torch.ops.aten.replication_pad2d.default) +@dynamo_tensorrt_converter(torch.ops.aten.replication_pad3d.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_replication_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.replication_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_circular_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.circular_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.pad.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.pad( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + pad=args[1], + mode=args_bounds_check(args, 2, "constant"), + value=args_bounds_check(args, 3, None), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec) +def upsample_nearest2d( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.upsample.upsample( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + out_shape=args_bounds_check(args, 1), + scale_factors=args_bounds_check(args, 2), + resize_mode="nearest", + align_corners=False, + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec) +def upsample_bilinear2d( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.upsample.upsample( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + out_shape=args_bounds_check(args, 1), + scale_factors=args_bounds_check(args, 3), + resize_mode="bilinear", + align_corners=args_bounds_check(args, 2), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.sort.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_sort( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.topk.sort( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + dim=args_bounds_check(args, 1, -1), + descending=args_bounds_check(args, 2, False), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.trunc.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_trunc( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.trunc( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.copy.default) +@enforce_tensor_types( + { + 1: (TRTTensor,), + } +) +def aten_ops_copy( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + src = args[1] + 
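+    # aten.copy is realized as a cast of `src` to its own dtype; force_layer=True
+    # ensures an identity layer is still emitted when no conversion is needed.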
return impl.cast.to_copy( + ctx, + target, + SourceIR.ATEN, + name, + src, + src.dtype, + force_layer=True, + ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 4f56ffbd85..f90c869c15 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -180,8 +180,7 @@ def cast_int_int_div_trt_tensor( def broadcastable( - a: TRTTensor, - b: TRTTensor, + a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray] ) -> bool: "Check if two tensors are broadcastable according to torch rules" a_shape = tuple(a.shape) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 26688695fa..ca71cb0b0c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,8 +2,8 @@ from . import ( activation, + addmm, attention, - argmax, cast, cat, condition, @@ -11,9 +11,11 @@ deconv, elementwise, embedding, + grid, linear, matmul, normalization, + pad, permutation, pool, reduce, @@ -23,6 +25,8 @@ slice, split, squeeze, + topk, unary, unsqueeze, + upsample, ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py index ac77f790cb..f578351ef2 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py @@ -235,36 +235,6 @@ def softplus_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float] ) -def clip( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input_val: TRTTensor, - alpha: float, - beta: float, -) -> TRTTensor: - operation_type = trt.ActivationType.CLIP - - def clip_dyn_range_fn(dyn_range: Tuple[float, float]) -> Tuple[float, float]: - def clip_fn(x: float) -> float: - return max(alpha, min(beta, x)) - - return clip_fn(dyn_range[0]), clip_fn(dyn_range[1]) - - return convert_activation( - ctx, - target, - source_ir, - name, - operation_type, - input_val, - alpha=alpha, - beta=beta, - dyn_range_fn=clip_dyn_range_fn, - ) - - def hard_sigmoid( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py new file mode 100644 index 0000000000..1a0690852a --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py @@ -0,0 +1,34 @@ +from typing import Optional, Union + +import numpy as np +import torch +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.types import TRTTensor + + +def addmm( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + mat1: Union[TRTTensor, torch.Tensor, np.ndarray], + mat2: Union[TRTTensor, torch.Tensor, np.ndarray], + *, + beta: Union[float, int], + alpha: Union[float, int], +) -> TRTTensor: + mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2) + if alpha != 1: + mm = impl.elementwise.mul( + ctx, target, SourceIR.ATEN, f"{name}_mul_alpha", mm, alpha + ) + if beta != 1: + input = impl.elementwise.mul( + ctx, target, SourceIR.ATEN, f"{name}_mul_beta", input, beta + ) + + return impl.elementwise.add(ctx, target, 
source_ir, f"{name}_add", input, mm) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index 24149d01b0..d6ffc77377 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Sequence, Union +from typing import Optional, Sequence, Union import numpy as np import torch @@ -6,12 +6,11 @@ from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( - SourceIR, get_positive_dim, get_trt_tensor, ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def cat( @@ -23,12 +22,12 @@ def cat( dim: int, ) -> Union[TRTTensor, Sequence[TRTTensor]]: trt_inputs = [] - for each_input in input: + for i, each_input in enumerate(input): if not isinstance(each_input, TRTTensor): - each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}") + each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}") trt_inputs.append(each_input) concat_layer = ctx.net.add_concatenation(trt_inputs) dim = get_positive_dim(dim, len(input[0].shape)) concat_layer.axis = dim - set_layer_name(concat_layer, target, name + "_gather", source_ir) + set_layer_name(concat_layer, target, f"{name}_gather", source_ir) return concat_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index 33b5fcbd87..26e0d59b8f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -32,10 +32,11 @@ def convNd( input: TRTTensor, weight: Union[TRTTensor, torch.Tensor, np.ndarray], bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - stride: Optional[Union[int, Sequence[int]]], - padding: Optional[Union[int, Sequence[int]]], - dilation: Optional[Union[int, Sequence[int]]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]], + dilation: Union[int, Sequence[int]], groups: Optional[int], + output_padding: Union[int, Sequence[int]] = 0, scale: Optional[Union[torch.Tensor, float]] = None, zero_point: Optional[Union[torch.Tensor, float]] = None, ) -> TRTTensor: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index ebb9b1bec2..f66bff7c82 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -32,10 +32,11 @@ def deconvNd( input: TRTTensor, weight: Union[TRTTensor, torch.Tensor, np.ndarray], bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - stride: Optional[Union[int, Sequence[int]]], - padding: Optional[Union[int, Sequence[int]]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]], + dilation: Union[int, Sequence[int]], groups: Optional[int], - dilation: Optional[Union[int, Sequence[int]]], + output_padding: Union[int, Sequence[int]] = 0, scale: Optional[Union[torch.Tensor, float]] = None, zero_point: Optional[Union[torch.Tensor, float]] = None, ) -> TRTTensor: @@ -86,7 +87,7 @@ def deconvNd( # add deconv layer deconv_layer = ctx.net.add_deconvolution_nd( input=input, - num_output_maps=weight.shape[0], + num_output_maps=weight.shape[1] * groups, kernel_shape=weight.shape[2:], 
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 3700242fe7..8282ee8698 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -58,8 +58,8 @@ def convert_binary_elementwise( source_ir: Optional[SourceIR], name: str, op_type: trt.ElementWiseOperation, - lhs_val: Union[int, float, TRTTensor, torch.Tensor], - rhs_val: Union[int, float, TRTTensor, torch.Tensor], + lhs_val: Union[int, float, bool, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, bool, TRTTensor, torch.Tensor], ) -> TRTTensor: """ This function adds a TensorRT elementwise layer. We allow both operands to be @@ -120,11 +120,11 @@ def convert_binary_elementwise( # Note that the dtype here is supposed to be the same as the scalar # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)): rhs_val = np.array( [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY) ) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)): lhs_val = np.array( [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 9f1143959f..a69fca944b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,7 +1,8 @@ from typing import Optional, Union -import numpy as np import tensorrt as trt +import torch +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -15,7 +16,6 @@ ) from torch_tensorrt.dynamo.conversion.impl.unary import sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary -from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter @@ -184,63 +184,21 @@ def clamp( source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - min_val: Optional[float] = None, - max_val: Optional[float] = None, + min_val: Optional[Union[int, float, TRTTensor]] = None, + max_val: Optional[Union[int, float, TRTTensor]] = None, ) -> TRTTensor: - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Clamp received input {input_val} that is not part " - "of the TensorRT region!" 
- ) - - def _add_layer( - ctx: ConversionContext, - input: TRTTensor, - val: float, - op: trt.ElementWiseOperation, - name: str, - ) -> ( - trt.ILayer - ): # TODO: Simplify and merge implementations, should just be max and min stacked - if not len(input.shape): - # clamping scalar - acc_ops_clamp_trt = get_trt_tensor( - ctx, - squeeze_left( - np.array( - [val], - dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), - ) - ), - f"{name}_clamp_{val}", - ) - else: - acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions - acc_ops_clamp_tensor = np.full( - acc_ops_clamp_shape, - val, - dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), - ) - acc_ops_clamp_trt = ctx.net.add_constant( - acc_ops_clamp_shape, acc_ops_clamp_tensor - ).get_output(0) - layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op) - return layer - + clamped_val = input_val if min_val is not None: - clamp_min_layer = _add_layer( - ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name + clamped_val = impl.elementwise.max( + ctx, target, source_ir, f"{name}_max", clamped_val, min_val ) - set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") - input_val = clamp_min_layer.get_output(0) + if max_val is not None: - clamp_max_layer = _add_layer( - ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name + clamped_val = impl.elementwise.min( + ctx, target, source_ir, f"{name}_min", clamped_val, max_val ) - set_layer_name(clamp_max_layer, target, f"{name}_clamp_max") - input_val = clamp_max_layer.get_output(0) - return input_val + return clamped_val def add( @@ -370,8 +328,8 @@ def logical_and( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -389,8 +347,8 @@ def logical_or( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -408,8 +366,8 @@ def logical_xor( target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: Union[TRTTensor, int, float, bool], + rhs_val: Union[TRTTensor, int, float, bool], ) -> TRTTensor: if isinstance(lhs_val, TRTTensor): lhs_val = cast_int_or_float_to_bool(ctx, name, lhs_val) @@ -422,13 +380,46 @@ def logical_xor( ) +def bitwise_and( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], +) -> TRTTensor: + return logical_and(ctx, target, source_ir, f"{name}_logical_and", lhs_val, rhs_val) + + +def bitwise_or( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], +) -> TRTTensor: + return logical_or(ctx, target, source_ir, f"{name}_logical_or", lhs_val, rhs_val) + + +def bitwise_xor( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + 
lhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], + rhs_val: Union[TRTTensor, int, float, torch.Tensor, bool], +) -> TRTTensor: + return logical_xor(ctx, target, source_ir, f"{name}_logical_xor", lhs_val, rhs_val) + + def eq( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -441,13 +432,30 @@ def eq( ) +def ne( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], +) -> TRTTensor: + return impl.unary.logical_not( + ctx, + target, + source_ir, + f"{name}_logical_not", + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) + + def gt( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -460,13 +468,31 @@ def gt( ) +def ge( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], +) -> TRTTensor: + return logical_or( + ctx, + target, + source_ir, + name, + gt(ctx, target, source_ir, f"{name}_gt", lhs_val, rhs_val), + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) + + def lt( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, - lhs_val: Union[TRTTensor, int, float], - rhs_val: Union[TRTTensor, int, float], + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], ) -> TRTTensor: return convert_binary_elementwise( ctx, @@ -477,3 +503,21 @@ def lt( lhs_val, rhs_val, ) + + +def le( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + lhs_val: TRTTensor, + rhs_val: Union[TRTTensor, int, float, torch.Tensor], +) -> TRTTensor: + return logical_or( + ctx, + target, + source_ir, + name, + lt(ctx, target, source_ir, f"{name}_lt", lhs_val, rhs_val), + eq(ctx, target, source_ir, f"{name}_eq", lhs_val, rhs_val), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index b7795ea1f3..ac9faf9f4d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -1,10 +1,13 @@ -from typing import Optional +import functools +from typing import Optional, Sequence, Tuple, Union +import numpy as np import torch +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -40,5 +43,133 @@ def embedding( # Implement embedding lookup with gather layer gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0) - set_layer_name(gather_layer, target, name + "_gather", source_ir) + 
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
     return gather_layer.get_output(0)
+
+
+def embedding_bag(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    weight: TRTTensor,
+    indices: TRTTensor,
+    offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
+    scale_grad_by_freq: bool,
+    mode: int,
+    sparse: bool,
+    per_sample_weights: Optional[TRTTensor],
+    include_last_offset: bool,
+) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
+    """
+    This function computes embedding bags.
+
+    In PyTorch, `offsets` is only used when the input is 1D. If the input is 2D of
+    shape (B, N), it will be treated as B bags (sequences), each of fixed length N,
+    and this will return B values aggregated in a way depending on the mode.
+    `offsets` is ignored and required to be None in this case.
+
+    However, according to the schema, `offsets` is required for input with any dimensions.
+    Accordingly, this function flattens N-D input to 1D and then calculates the embedding bags.
+    """
+
+    # TODO: support 2D inputs
+    # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))
+    reduce_name = ""
+    if mode == 0:  # sum
+        reduce_op = functools.partial(
+            impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir
+        )
+        reduce_name = "sum"
+    elif mode == 1:  # mean
+        reduce_op = functools.partial(
+            impl.reduce.mean, ctx=ctx, target=target, source_ir=source_ir
+        )
+        reduce_name = "mean"
+    elif mode == 2:  # max
+        reduce_op = functools.partial(
+            impl.reduce.max,
+            ctx=ctx,
+            target=target,
+            source_ir=source_ir,
+            return_indices=False,
+        )
+        reduce_name = "max"
+
+    # calculate the embedding
+    embed = embedding(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_embedding",
+        indices,
+        weight,
+        scale_grad_by_freq,
+        sparse,
+    )
+
+    # apply the per-sample weights to the embedding
+    if per_sample_weights is not None:
+        assert (
+            per_sample_weights.shape == indices.shape
+        ), f"`per_sample_weights` (shape: {per_sample_weights.shape}) must have exactly the same shape as indices/input (shape: {indices.shape})!"
+        per_sample_weights = get_trt_tensor(
+            ctx, per_sample_weights, f"{name}_per_sample_weights", np.float32
+        )
+        per_sample_weights = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_per_sample_weights",
+            per_sample_weights,
+            (-1, 1),
+        )
+        embed = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_mul_per_sample_weights",
+            embed,
+            per_sample_weights,
+        )
+
+    offsets = to_numpy(offsets)
+
+    if include_last_offset is False:
+        # add the end index to offsets
+        offsets = np.append(offsets, indices.shape[0])
+    else:
+        # modify the last index of offsets to the end index
+        # however, the PyTorch docs say that if `include_last_offset` is True, the size of
+        # offsets is equal to the number of bags + 1. The last element is the size of
+        # the input, or the ending index position of the last bag (sequence).
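+        # e.g. (hypothetical values) 10 indices with offsets=[0, 4, 7, 10] describe
+        # bags [0:4), [4:7), [7:10); the last entry must equal the input length for
+        # the uniform slicing below.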
+ offsets[-1] = indices.shape[0] + + # separately reduce embeddings for different bags + reduced_embed = [] + len_offsets = len(offsets) + for i in range(len_offsets - 1): + if offsets[i] < offsets[i + 1]: + sliced_embed = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_embed_{i}", + embed, + 0, + int(offsets[i]), + int(offsets[i + 1]), + 1, + ) + reduced_sliced_embed = reduce_op( + name=f"{name}_{reduce_name}_{i}", + input_val=sliced_embed, + dim=0, + keepdim=True, + ) + reduced_embed.append(reduced_sliced_embed) + + out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0) + # out = reduce_op(input_val=embed, dim=1, keepdim=False) # Note: This implementation doesn't work for N-dim + + return out, None, None, None diff --git a/py/torch_tensorrt/dynamo/conversion/impl/grid.py b/py/torch_tensorrt/dynamo/conversion/impl/grid.py new file mode 100644 index 0000000000..63ff93b0c7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/grid.py @@ -0,0 +1,45 @@ +from typing import Optional + +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + +# nearest, linear, cubic +GridSamplerInterpolationMode = { + 0: trt.InterpolationMode.NEAREST, + 1: trt.InterpolationMode.LINEAR, + 2: trt.InterpolationMode.CUBIC, +} + +# zeros, border, reflection +GridSamplerSampling = { + 0: trt.SampleMode.FILL, + 1: trt.SampleMode.CLAMP, + 2: trt.SampleMode.REFLECT, +} + + +def grid( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + grid: TRTTensor, + interpolation_mode: int, + padding_mode: int, + align_corners: bool, +) -> TRTTensor: + grid_layer = ctx.net.add_grid_sample(input, grid) + assert interpolation_mode in GridSamplerInterpolationMode + grid_layer.interpolation_mode = GridSamplerInterpolationMode.get( + interpolation_mode, None + ) + assert padding_mode in GridSamplerSampling + grid_layer.sample_mode = GridSamplerSampling.get(padding_mode, None) + grid_layer.align_corners = align_corners + set_layer_name(grid_layer, target, name + "_grid_layer", source_ir) + return grid_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py new file mode 100644 index 0000000000..3764667ffb --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -0,0 +1,205 @@ +from typing import Optional, Sequence, Union + +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import ( + has_dynamic_shape, + set_layer_name, +) +from torch_tensorrt.fx.types import TRTTensor + +""" +Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0. +Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding +mode and clamp, and supports padding output with dynamic shape. 
+""" + + +def constant_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + pad: Sequence[int], + value: Union[int, float] = 0, +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + rank = len(input.shape) + + if len(pad) // 2 > rank: + raise RuntimeError( + f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension." + ) + + start_list = [0] * rank + new_shape = list(input.shape) + + for i in range(0, len(pad) // 2): + start_list[-i - 1] = -pad[i * 2] + new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1] + + stride_list = [1] * rank + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype) + layer.set_input(4, value_const) + layer.mode = trt.SliceMode.FILL + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + +def reflection_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + padding: Sequence[int], +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + rank = len(input.shape) + + if len(padding) // 2 > rank: + raise RuntimeError( + f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension." + ) + + start_list = [0] * rank + new_shape = list(input.shape) + + for i in range(0, len(padding) // 2): + start_list[-i - 1] = -padding[i * 2] + new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1] + + stride_list = [1] * rank + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + layer.mode = trt.SliceMode.REFLECT + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + +def replication_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + padding: Sequence[int], +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + rank = len(input.shape) + + if len(padding) // 2 > rank: + raise RuntimeError( + f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension." + ) + + start_list = [0] * rank + new_shape = list(input.shape) + + for i in range(0, len(padding) // 2): + start_list[-i - 1] = -padding[i * 2] + new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1] + + stride_list = [1] * rank + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + layer.mode = trt.SliceMode.CLAMP + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + +def circular_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + pad: Sequence[int], +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + rank = len(input.shape) + + if len(pad) // 2 > rank: + raise RuntimeError( + f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension." 
+ ) + + start_list = [0] * rank + new_shape = list(input.shape) + + for i in range(0, len(pad) // 2): + start_list[-i - 1] = -pad[i * 2] + new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1] + + stride_list = [1] * rank + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + layer.mode = trt.SliceMode.WRAP + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + +def pad( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + pad: Sequence[int], + mode: str = "constant", + value: Optional[float] = None, +) -> TRTTensor: + if mode == "constant": + return constant_padNd( + ctx, + target, + source_ir, + f"{name}_{mode}", + input, + pad, + value if value is not None else 0, + ) + elif mode == "reflect": + return reflection_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad) + elif mode == "replicate": + return replication_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad) + elif mode == "circular": + return circular_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad) + else: + raise RuntimeError( + f'We currently only support for `mode` in ["constant", "reflect", "replicate", "circular"], but got {mode}' + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py index 2fcd57a7f6..04f5596581 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -19,7 +19,7 @@ def amax( source_ir: Optional[SourceIR], name: str, input_val: TRTTensor, - dim: Union[int, Sequence[int]], + dim: Sequence[int] = [], keepdim: bool = False, ) -> TRTTensor: if (isinstance(input_val, TRTTensor)) and ( @@ -27,7 +27,7 @@ def amax( ): input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) - if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0): + if isinstance(dim, (tuple, list)) and len(dim) == 0: dim = tuple(range(len(input_val.shape))) layer = ctx.net.add_reduce( @@ -40,6 +40,33 @@ def amax( return layer.get_output(0) +def amin( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + dim: Sequence[int] = [], + keepdim: bool = False, +) -> TRTTensor: + if (isinstance(input_val, TRTTensor)) and ( + input_val.dtype == trt.int8 or input_val.dtype == trt.int32 + ): + input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) + + if isinstance(dim, (tuple, list)) and len(dim) == 0: + dim = tuple(range(len(input_val.shape))) + + layer = ctx.net.add_reduce( + input_val, + trt.ReduceOperation.MIN, + axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))), + keep_dims=keepdim, + ) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + def sum( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 70f94cdca8..dc33129d24 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -3,6 +3,7 @@ import numpy as np import tensorrt as trt +import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -80,14 +81,20 @@ def index( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - index: 
Union[TRTTensor, Sequence[TRTTensor]], + index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] - # _LOGGER.debug(f"The index shape is {index.shape}") # check if the input is dynamic dynamic_shape = has_dynamic_shape(input.shape) - + # is_numpy is a flag to specify if all the indices are numpy or torchTensor. + # If any is not this flag will be set to False + _LOGGER.debug( + "Determining whether aten.index constant-index optimization can be invoked" + ) + is_numpy = all( + isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None + ) # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None @@ -95,8 +102,13 @@ def index( if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") adv_indx_indices.append(i) - # torch.nn.parameter.Parameter=> torch.Tensor - ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") + # torch.nn.parameter.Parameter=> numpy array + # numpy array is kept as numpy + # other cases are kept as TRTTensor + if is_numpy: + ind = to_numpy(ind) + else: + ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") if last_index is not None: assert broadcastable( ind, last_index @@ -110,8 +122,9 @@ def index( set_layer_name(identity_layer, target, name + "_index_identity", source_ir) return identity_layer.get_output(0) elif len(tensor_indices) == 1: - # This case works - indices_tensor = tensor_indices[0] + indices_tensor = get_trt_tensor( + ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor" + ) index = adv_indx_indices[0] _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") gather_layer = ctx.net.add_gather(input, indices_tensor, index) @@ -150,6 +163,7 @@ def index( if i not in adv_indx_indices: new_order.append(i) _LOGGER.debug(f"The new transpose order is {new_order}") + transpose_layer.second_transpose = tuple(new_order) set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir) transpose_tensor = transpose_layer.get_output(0) @@ -175,47 +189,58 @@ def index( concat_tensor = concat_tensor_layer.get_output(0) reshape_layer = ctx.net.add_shuffle(transpose_tensor) - # check this reshape_layer.set_input(1, concat_tensor) flatten_tensor = reshape_layer.get_output(0) + _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}") # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the # // j dimension of input x. 
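+    # In other words, the advanced indices are fused into one linear index via
+    # row-major stride products, so a single gather over the flattened tensor
+    # suffices.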
- multiplier = get_trt_tensor( - ctx, - dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], - name + "_dim_last", - ) - cum_adv_index = tensor_indices[adv_indx_count - 1] - for i in range(adv_indx_count - 2, -1, -1): - adv_index = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_index_intermediate_{i}", - trt.ElementWiseOperation.PROD, - multiplier, - tensor_indices[i], + if is_numpy: + multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]] + cum_adv_index = tensor_indices[adv_indx_count - 1] + for i in range(adv_indx_count - 2, -1, -1): + adv_index = multiplier * tensor_indices[i] + cum_adv_index = cum_adv_index + adv_index + multiplier = multiplier * input_shape[adv_indx_indices[i]] + cum_adv_index = get_trt_tensor( + ctx, cum_adv_index, name + "_index_sum_intermediate" ) - cum_adv_index = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_index_sum_intermediate_{i}", - trt.ElementWiseOperation.SUM, - cum_adv_index, - adv_index, - ) - multiplier = convert_binary_elementwise( + else: + multiplier = get_trt_tensor( ctx, - target, - source_ir, - name + f"_index_intermediate_xj_{i}", - trt.ElementWiseOperation.PROD, - multiplier, - dim_tensor_list[adv_indx_indices[i]], + dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], + name + "_dim_last", ) + cum_adv_index = tensor_indices[adv_indx_count - 1] + for i in range(adv_indx_count - 2, -1, -1): + adv_index = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_index_intermediate_{i}", + trt.ElementWiseOperation.PROD, + multiplier, + tensor_indices[i], + ) + cum_adv_index = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_index_sum_intermediate_{i}", + trt.ElementWiseOperation.SUM, + cum_adv_index, + adv_index, + ) + multiplier = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_index_intermediate_xj_{i}", + trt.ElementWiseOperation.PROD, + multiplier, + dim_tensor_list[adv_indx_indices[i]], + ) gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0) set_layer_name( @@ -238,7 +263,7 @@ def index( adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1 ): - _LOGGER.debug(f"The indices are continuous in this case") + _LOGGER.debug("The indices are continuous in this case") concat_tensor_reshape.append( get_trt_tensor(ctx, -1, name + "_dynamic_concat") ) @@ -262,7 +287,7 @@ def index( source_ir, ) unfold_tensor = regular_index_shuffle_layer.get_output(0) - _LOGGER.debug(f"The tensor is unfolded now") + _LOGGER.debug("The tensor is unfolded now") _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}") # Transpose folded advanced indexed axis to its original location. 
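For intuition, the linear-index construction above corresponds to the following NumPy identity (a minimal sketch with hypothetical shapes; this is illustrative only, not the converter's API):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)           # source tensor
    i0, i1 = np.array([0, 1]), np.array([2, 0])  # advanced indices for dims 0, 1
    flat = x.reshape(2 * 3, 4)                   # fold the indexed dims together
    linear = i0 * 3 + i1                         # row-major stride product
    assert (flat[linear] == x[i0, i1]).all()     # gather on the flattened view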
@@ -317,7 +342,7 @@ def index( reshape_output = unfold_advanced_shuffle_layer.get_output(0) else: - _LOGGER.debug(f"The indices are not continuous in this case") + _LOGGER.debug("The indices are not continuous in this case") concat_final_tensor = [] concat_final_tensor.append(cum_adv_index_shape_tensor) for i in range(0, rank): diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index cd396afc3f..91ac4a7042 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, Sequence import numpy as np import tensorrt as trt @@ -197,3 +197,31 @@ def cumsum( set_layer_name(loop_output, target, f"{name}_loop_output", source_ir) loop_output.set_input(1, trip_limit) return loop_output.get_output(0) + + +def tile( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dims: Sequence[int], +) -> TRTTensor: + diff = len(dims) - len(input.shape) + if diff > 0: + # prepend 1 to input.shape + new_shape = (1,) * diff + tuple(input.shape) + input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape + ) + elif diff < 0: + # prepend 1 to dims + dims = (1,) * -diff + tuple(dims) + + shapes = [i * j for i, j in zip(input.shape, dims)] + starts = [0] * len(dims) + strides = [1] * len(dims) + layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides)) + layer.mode = trt.SampleMode.WRAP + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py new file mode 100644 index 0000000000..41f6f990f2 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py @@ -0,0 +1,136 @@ +from typing import Optional, Tuple, Union + +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + flatten_dims, + get_axes_for_reduce_op, + get_positive_dim, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def argmax_argmin( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + topk_option: trt.TopKOperation, + dim: Optional[int], + keep_dim: bool = False, +) -> TRTTensor: + if input.dtype == trt.int32: + input = cast_trt_tensor(ctx, input, trt.float32, name, target, source_ir) + + # Three different cases here: + # 1. dim == None, flatten input tensor first, keep_dim will be ignore and the output rank == input rank + # 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2 + # 3. 
normal cases, no additional handlings + out = input + + if dim is None: + new_shape = (*flatten_dims(input, 0, -1), 1) + out = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_flatten", input, new_shape + ) + elif len(input.shape) == 1: + new_shape = (*input.shape, 1) + out = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_broadcast", input, new_shape + ) + + # Reduce over the flattened input if the dimension is None, otherwise the specified dimension + reduce_mask = get_axes_for_reduce_op( + get_positive_dim(dim if dim is not None else 0, len(out.shape)) + ) + + topk_layer = ctx.net.add_topk(out, topk_option, 1, reduce_mask) + set_layer_name(topk_layer, target, name, source_ir) + + out = topk_layer.get_output(1) + + if dim is None: + new_shape = ((1,) * len(input.shape)) if keep_dim else () + out = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_unflatten", out, new_shape + ) + elif len(input.shape) == 1: + out = impl.squeeze.squeeze( + ctx, + target, + source_ir, + f"{name}_squeeze", + out, + 1 if keep_dim else (0, 1), + ) + elif not keep_dim: + out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim) + + return out + + +def argmax( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[int], + keep_dim: bool = False, +) -> TRTTensor: + return argmax_argmin( + ctx, target, source_ir, name, input, trt.TopKOperation.MAX, dim, keep_dim + ) + + +def argmin( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[int], + keep_dim: bool = False, +) -> TRTTensor: + return argmax_argmin( + ctx, target, source_ir, name, input, trt.TopKOperation.MIN, dim, keep_dim + ) + + +def sort( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + descending: bool, + return_indices: bool = True, +) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]: + if descending: + topk_layer = ctx.net.add_topk( + input, + trt.TopKOperation.MAX, + input.shape[dim], + get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))), + ) + else: + topk_layer = ctx.net.add_topk( + input, + trt.TopKOperation.MIN, + input.shape[dim], + get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))), + ) + + set_layer_name(topk_layer, target, name, source_ir) + + if return_indices: + return topk_layer.get_output(0), topk_layer.get_output(1) + else: + return topk_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 58c5f6ff4a..9ed5d0636d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -1,10 +1,14 @@ from typing import Optional import tensorrt as trt +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + get_trt_tensor, +) from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary from torch_tensorrt.fx.types import TRTTensor @@ -336,6 +340,18 @@ def logical_not( ) +def bitwise_not( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + 
name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    return impl.unary.logical_not(
+        ctx, target, source_ir, f"{name}_logical_not", input_val
+    )
+
+
 def sign(
     ctx: ConversionContext,
     target: Target,
@@ -419,3 +435,27 @@ def erf(
     return convert_unary(
         ctx, target, source_ir, name, trt.UnaryOperation.ERF, input_val
     )
+
+
+def trunc(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    if input_val.dtype not in (trt.float16, trt.float32):
+        return impl.cast.to_copy(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_copy",
+            input_val,
+            input_val.dtype,
+            force_layer=True,
+        )
+
+    dividend = get_trt_tensor(ctx, 1, f"{name}_dividend")
+    return impl.elementwise.trunc_div(
+        ctx, target, source_ir, f"{name}_trunc", input_val, dividend
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
new file mode 100644
index 0000000000..594bb4167c
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -0,0 +1,67 @@
+from typing import Optional, Sequence
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def upsample(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    out_shape: Optional[Sequence[int]],
+    scale_factors: Optional[Sequence[float]],
+    resize_mode: str,
+    align_corners: bool,
+) -> TRTTensor:
+    resize_layer = ctx.net.add_resize(input)
+    # Output size calculation
+    # PyTorch assumes that exactly one of out_shape/scale_factors is None
+    # PyTorch assumes that dimensions match for out_shape/scale_factors
+    if out_shape is not None:
+        resize_layer.shape = list(input.shape)[:2] + list(out_shape)
+    elif scale_factors is not None:
+        resize_layer.scales = [1.0, 1.0] + list(scale_factors)
+    else:
+        raise RuntimeError(
+            "At least one of out_shape and scale_factors should be specified."
+        )
+
+    # Interpolation mode
+    if resize_mode == "nearest":
+        resize_layer.resize_mode = trt.ResizeMode.NEAREST
+    elif resize_mode == "bilinear":
+        resize_layer.resize_mode = trt.ResizeMode.LINEAR
+        if align_corners is None or not align_corners:
+            raise RuntimeError(
+                f"Interpolation works differently if align_corners is False for {resize_mode} mode in PyTorch and TensorRT."
+            )
+    else:
+        raise RuntimeError(
+            f"Interpolation mode is {resize_mode} which is not supported by TensorRT."
+ ) + + if resize_mode == "nearest": + resize_layer.coordinate_transformation = ( + trt.ResizeCoordinateTransformation.ASYMMETRIC + ) + elif resize_mode == "bilinear": + # align corners + if align_corners is not None and align_corners: + resize_layer.coordinate_transformation = ( + trt.ResizeCoordinateTransformation.ALIGN_CORNERS + ) + else: + resize_layer.coordinate_transformation = ( + trt.ResizeCoordinateTransformation.ASYMMETRIC + ) + + set_layer_name(resize_layer, target, name, source_ir) + + out = resize_layer.get_output(0) + return out diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py index 3a67c47fa3..f83e0e5008 100644 --- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py @@ -2,6 +2,7 @@ import operator from typing import Dict, Sequence, Tuple, Union +import numpy as np import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -35,3 +36,14 @@ def generic_evaluator( f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}" ) return target(*args) + + +@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step) +def aten_ops_arange_start_step( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return np.arange(*args) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 158523a7e1..981c80f9fa 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -105,22 +105,6 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor: return x -@register_torch_trt_decomposition( - torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS -) -def addmm_replacement( - input_: torch.Tensor, - mat1: torch.Tensor, - mat2: torch.Tensor, - *, - beta: int = 1, - alpha: int = 1, -) -> torch.Tensor: - return torch.add( - torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha) - ) - - @register_torch_trt_decomposition( torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 93ad4655b5..d6e12f5215 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -6,6 +6,7 @@ from .constant_folding import constant_fold from .fuse_prims_broadcast import fuse_prims_broadcast from .lower_efficient_attention import lower_efficient_attention +from .lower_linear import lower_linear from .pass_manager import DynamoPassManager from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output @@ -18,6 +19,7 @@ constant_fold, repair_input_as_output, lower_efficient_attention, + lower_linear, fuse_prims_broadcast, replace_max_pool_with_indices, view_to_reshape, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 65f6cea1ed..e9a628ea43 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -2,7 +2,7 @@ from 
typing import Any, Sequence import torch -from torch._inductor.freezing import ConstantFolder, replace_node_with_constant +from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py index 944b0788b0..6984a70254 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py @@ -5,7 +5,6 @@ import torch from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, - get_tensor_placeholders, ) logger = logging.getLogger(__name__) @@ -36,34 +35,13 @@ def efficient_attention_replacement() -> ( ): """Constructs the original and replacement functions for efficient attention""" - # Empty boilerplate function taking in three Tensors and returning one - def boilerplate( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - ... - - # Trace boilerplate function and extract placeholder and output nodes - orig = torch.fx.symbolic_trace(boilerplate) - q, k, v = get_tensor_placeholders(orig) - output = [node for node in orig.graph.nodes if node.op == "output"][0] - - # Graph types to replace are those which use the _scaled_dot_product_efficient_attention - # function and extract only the first element - with orig.graph.inserting_before(output): - att = orig.graph.call_function( - torch.ops.aten._scaled_dot_product_efficient_attention.default, - args=(q, k, v, None, False), + # Original graph + def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, None, False ) - out = orig.graph.call_function( - operator.getitem, - args=(att, 0), - ) - - # Assign the output of the graph to be the single getitem output - output.args = (out,) - - orig.graph.lint() - orig.recompile() + out = operator.getitem(outputs, 0) + return out # Replacement graph consists of the functional version of scaled_dot_product_attention def replacement( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py new file mode 100644 index 0000000000..75ad067a3f --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -0,0 +1,47 @@ +import logging +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def lower_linear( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + """Replace the decomposed linear pattern (permute + addmm) with aten.linear, which can be easily converted to TRT""" + orig, replacement = linear_replacement() + + if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering linear:\n{gm.graph}") + + return gm + + +def linear_replacement() -> ( + Tuple[ + torch.fx.GraphModule, + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], + ] +): + """Constructs the original and replacement functions for linear""" + + # Original graph + def orig( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) 
-> torch.Tensor: + W_T = torch.ops.aten.permute.default(weight, [1, 0]) + out = torch.ops.aten.addmm.default(bias, input, W_T) + return out + + # Replacement graph + def replacement( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + return torch.ops.aten.linear.default(input, weight, bias) + + return orig, replacement diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 5bdbb8919b..1027d8eac3 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -42,8 +42,10 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or node.op == "get_attr" - ) and node_name not in self.torch_executed_ops: + (node in CONVERTERS or node.op == "get_attr") + and node_name not in self.torch_executed_ops + and node.target not in self.torch_executed_ops + ): # If node is a proper, supported computational node, store the operator if not node.is_impure() and node.op != "get_attr": if node_name not in self.supported_operators: diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 092bdabfd0..6889bad22c 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,8 +1,9 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set +from typing import Collection, Dict, List, Mapping, Optional, Sequence import torch from torch.fx.graph_module import GraphModule +from torch.fx.node import Target from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.passes.operator_support import OperatorSupport, SupportDict from torch_tensorrt.dynamo._defaults import ( @@ -133,16 +134,14 @@ class TorchTensorRTOperatorSupport(OperatorSupport): # type: ignore[misc] def __init__( self, support_dict: Optional[SupportDict] = None, - torch_executed_ops: Optional[Set[str]] = None, + torch_executed_ops: Collection[Target] = set(), ): super().__init__(support_dict) # Initialize sets of supported/unsupported operators self.supported_operators: Dict[str, int] = {} self.unsupported_operators: Dict[str, int] = {} - self.torch_executed_ops: Set[str] = ( - torch_executed_ops if torch_executed_ops is not None else set() - ) + self.torch_executed_ops: Collection[Target] = torch_executed_ops def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node @@ -150,8 +149,10 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or node.op == "get_attr" - ) and node_name not in self.torch_executed_ops: + (node in CONVERTERS or node.op == "get_attr") + and node_name not in self.torch_executed_ops + and node.target not in self.torch_executed_ops + ): # If node is a proper, supported computational node, store the operator if not node.is_impure() and node.op != "get_attr": if node_name not in self.supported_operators: @@ -201,7 +202,7 @@ def partition( gm: torch.fx.GraphModule, verbose: bool = DEBUG, min_block_size: int = MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Set[str]] = None, + torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, ) -> torch.fx.GraphModule: """Partition an FX 
GraphModule with aten ops into TRT engines @@ -211,16 +212,12 @@ def partition( gm: FX GraphModule to partition verbose: Bool representing whether to print operator support min_block_size: Minimum number of operators per TRT-Engine Block - torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage + torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Whether to require that all operators be run in TRT Returns: torch.fx.GraphModule """ - supported_ops = TorchTensorRTOperatorSupport( - torch_executed_ops=torch_executed_ops - if torch_executed_ops is not None - else set() - ) + supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops) partitioner = TRTPartitioner( gm, supported_ops, diff --git a/py/torch_tensorrt/dynamo/tools/opset_coverage.py b/py/torch_tensorrt/dynamo/tools/opset_coverage.py index a18aecf087..bfa57d3ed8 100644 --- a/py/torch_tensorrt/dynamo/tools/opset_coverage.py +++ b/py/torch_tensorrt/dynamo/tools/opset_coverage.py @@ -45,6 +45,10 @@ class OpsetCoverage: Path(os.path.dirname(torchgen.__file__)) / "packaged/ATen/native/tags.yaml" ) +DYNAMO_REGISTRY_NAME = "Dynamo ATen Converters Registry" +FX_REGISTRY_NAME = "FX ATen Converters Registry" +FX_LEGACY_REGISTRY_NAME = "FX Legacy ATen Converters Registry" + def get_aten_ops() -> List[Tuple[str, str]]: parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH) @@ -137,13 +141,25 @@ def opset_coverage( _, registry_data = c_registry.get_all_converters_with_target( target, return_registry_info=True ) + if registry_data is not None: - if registry_data["Dynamo ATen Converters Registry"] >= 1: + if ( + DYNAMO_REGISTRY_NAME in registry_data + and registry_data[DYNAMO_REGISTRY_NAME] >= 1 + ): status = SupportStatus.CONVERTED support_count += 1 - elif registry_data["FX Legacy ATen Converters Registry"] >= 1: + elif ( + FX_REGISTRY_NAME in registry_data + and registry_data[FX_REGISTRY_NAME] >= 1 + ) or ( + FX_LEGACY_REGISTRY_NAME in registry_data + and registry_data[FX_LEGACY_REGISTRY_NAME] >= 1 + ): status = SupportStatus.LEGACY_CONVERTED legacy_count += 1 + else: + raise Exception(f"Op belongs to unknown registry: {registry_data}") support_status[target_str] = { "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}", diff --git a/pyproject.toml b/pyproject.toml index 4769c05192..3ae3c5fbd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ requires = [ "typing-extensions>=4.7.0", "future>=0.18.3", "tensorrt>=8.6,<8.7", - "torch>=2.1.0,<2.2.0", + "torch>=2.2.0,<2.3.0", "pybind11==2.6.2", "numpy", ] @@ -41,7 +41,7 @@ readme = {file = "py/README.md", content-type = "text/markdown"} requires-python = ">=3.8" keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] dependencies = [ - "torch>=2.1.0,<2.2.0", + "torch>=2.2.0,<2.3.0", "tensorrt>=8.6,<8.7", "packaging>=23", "numpy", @@ -50,7 +50,7 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -torchvision = ["torchvision>=0.16.0,<0.17.0"] +torchvision = ["torchvision>=0.17.0,<0.18.0"] [project.urls] Homepage = "https://pytorch.org/tensorrt" diff --git a/tests/core/conversion/converters/test_matrix_multiply.cpp b/tests/core/conversion/converters/test_matrix_multiply.cpp index 9c84ba22f6..50248f379a 100644 --- a/tests/core/conversion/converters/test_matrix_multiply.cpp +++ 
b/tests/core/conversion/converters/test_matrix_multiply.cpp @@ -21,9 +21,8 @@ TEST(Converters, ATenMMConvertsCorrectly) { params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); - auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) { @@ -42,9 +41,131 @@ TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) { params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); - auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenMM1d2dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(0, 5, {10}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {10, 1}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenMM1d3dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(0, 5, {10}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {2, 10, 8}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenMM1d4dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(0, 5, {10}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {2, 3, 10, 8}, {at::kCUDA}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenMM3d1dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto 
g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(0, 5, {2, 10, 8}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {8}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenMM2d1dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(0, 5, {1, 10}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {10}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenMM4d1dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, %1 : Tensor): + %2 : Tensor = aten::matmul(%0, %1) + return (%2))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(0, 5, {2, 3, 10, 8}, {at::kCUDA}); + auto in2 = at::randint(0, 5, {8}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenBMMConvertsCorrectly) { @@ -63,9 +184,8 @@ TEST(Converters, ATenBMMConvertsCorrectly) { params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); - auto trt = trt_results[0].reshape_as(jit_results[0]); - ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); } TEST(Converters, ATenBADDBMMConvertsCorrectly) { diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index 327fbc3fb9..f57ab7634e 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -1,8 +1,9 @@ +from typing import Dict, List, Tuple + import torch import torch.nn as nn -from transformers import BertModel, BertTokenizer, BertConfig import torch.nn.functional as F -from typing import Tuple, List, Dict +from transformers import BertConfig, BertModel, BertTokenizer # Sample Pool Model (for testing plugin serialization) @@ -180,10 +181,12 @@ def BertModule(): num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, + use_cache=False, + output_attentions=False, + output_hidden_states=False, torchscript=True, ) - model = BertModel(config) + model = BertModel.from_pretrained(model_name, config=config) model.eval() - model = 
BertModel.from_pretrained(model_name, torchscript=True) traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) return traced_model diff --git a/tests/modules/requirements.txt b/tests/modules/requirements.txt index 51fcf95500..da63a6dad1 100644 --- a/tests/modules/requirements.txt +++ b/tests/modules/requirements.txt @@ -1,3 +1,3 @@ -timm==v0.9.2 -transformers==4.30.0 +timm +transformers torchvision diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index db9520ccd5..3885627b5f 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -365,5 +365,39 @@ def forward(self, x, y): torch._dynamo.reset() +class TestDeconvolution(TestCase): + def test_ConvTranspose2d(self): + class Up(torch.nn.Module): + def __init__(self, in_channels, out_channels, upsample_stride): + super().__init__() + self.up = torch.nn.ConvTranspose2d( + in_channels, + out_channels, + upsample_stride, + stride=upsample_stride, + bias=False, + ) + + def forward(self, x): + return self.up(x) + + device = torch.device("cuda:0") + model = Up(64, 128, 2).to(device) + model.eval() + print(model) + + x = torch.rand((1, 64, 100, 100)).to(device) + model_opt = torch.compile( + model, + backend="torch_tensorrt", + options={ + "min_block_size": 1, + "debug": True, + }, + ) + with torch.no_grad(): + _ = model_opt(x) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_addmm_aten.py b/tests/py/dynamo/conversion/test_addmm_aten.py new file mode 100644 index 0000000000..6108d3ea6d --- /dev/null +++ b/tests/py/dynamo/conversion/test_addmm_aten.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestAddmmConverter(DispatchTestCase): + @parameterized.expand( + [ + ((2, 2), (2, 3), (3, 2)), + ((4, 6), (4, 5), (5, 6)), + ((2, 1), (2, 3), (3, 1)), + ((4, 1), (4, 1), (1, 1)), + ((1, 2), (1, 3), (3, 2)), + ] + ) + def test_addmm(self, input_shape, mat1_shape, mat2_shape): + class Addmm(nn.Module): + def forward(self, input, mat1, mat2): + return torch.ops.aten.addmm.default(input, mat1, mat2) + + inputs = [ + torch.randn(input_shape), + torch.randn(mat1_shape), + torch.randn(mat2_shape), + ] + + self.run_test( + Addmm(), + inputs, + ) + + @parameterized.expand( + [ + ((2, 2), (2, 3), (3, 2), 1.0, 1.0), + ((4, 6), (4, 5), (5, 6), 1.2, 0.8), + ((2, 1), (2, 3), (3, 1), 3, 2), + ((4, 1), (4, 1), (1, 1), 1, 1), + ((1, 2), (1, 3), (3, 2), 2, 1.0), + ((1, 2), (1, 3), (3, 2), 1, 2.0), + ] + ) + def test_addmm_scale(self, input_shape, mat1_shape, mat2_shape, beta, alpha): + class Addmm(nn.Module): + def forward(self, input, mat1, mat2): + return torch.ops.aten.addmm.default( + input, mat1, mat2, beta=beta, alpha=alpha + ) + + inputs = [ + torch.randn(input_shape), + torch.randn(mat1_shape), + torch.randn(mat2_shape), + ] + + self.run_test( + Addmm(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_amin_aten.py b/tests/py/dynamo/conversion/test_amin_aten.py new file mode 100644 index 0000000000..03ae9b6113 --- /dev/null +++ b/tests/py/dynamo/conversion/test_amin_aten.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + 
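+# For reference, a minimal sketch of the aten.amin semantics these cases exercise +# (standard PyTorch reduction behavior, assumed here; not part of the converter under test): +# x = torch.randn(3, 2, 4) +# torch.ops.aten.amin.default(x, 1, True).shape # -> torch.Size([3, 1, 4]) +# torch.ops.aten.amin.default(x, [0, 2], False).shape # -> torch.Size([2])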
+ +class TestAminConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3, 2, 4), 1, True), + ((2, 3, 4, 5), 3, True), + ((2, 3, 4, 5), 2, False), + ((6, 7, 5, 4, 5), 4, False), + ((1, 5, 2, 1), -1, True), + ] + ) + def test_amin_dim_int_default(self, input_shape, dim, keep_dims): + class Amin(nn.Module): + def forward(self, x): + return torch.ops.aten.amin.default(x, dim, keep_dims) + + inputs = [torch.randn(*input_shape)] + self.run_test( + Amin(), + inputs, + ) + + @parameterized.expand( + [ + ((1, 2, 4), [], True), + ((3, 2, 4), [1], True), + ((2, 1, 4, 5), [0, 3], True), + ((2, 3, 4, 5), [0, 1, 2, 3], False), + ((6, 7, 5, 4, 5), [1, 3, 4], False), + ] + ) + def test_amin_dim_tuple_default(self, input_shape, dim, keep_dims): + class Amin(nn.Module): + def forward(self, x): + return torch.ops.aten.amin.default(x, dim, keep_dims) + + inputs = [torch.randn(*input_shape)] + self.run_test( + Amin(), + inputs, + ) + + @parameterized.expand( + [ + ((3, 2, 4), 1, True, torch.int, 0, 5), + ((2, 3, 4, 5), 3, True, torch.int, -10, 10), + ((2, 3, 4, 5), 2, False, torch.int32, -5, 0), + ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5), + ((1, 5, 2, 1), -4, False, torch.int32, -5, 5), + ] + ) + def test_amin_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high): + class Amin(nn.Module): + def forward(self, x): + return torch.ops.aten.amin.default(x, dim, keep_dims) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + Amin(), + inputs, + check_dtype=False, + ) + + @parameterized.expand( + [ + ((1, 2, 4), [], True, torch.int, 0, 5), + ((3, 2, 4), [1], True, torch.int, 0, 5), + ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10), + ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0), + ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5), + ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5), + ] + ) + def test_amin_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high): + class Amin(nn.Module): + def forward(self, x): + return torch.ops.aten.amin.default(x, dim, keep_dims) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + Amin(), + inputs, + check_dtype=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py new file mode 100644 index 0000000000..035b957865 --- /dev/null +++ b/tests/py/dynamo/conversion/test_arange_aten.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestArangeConverter(DispatchTestCase): + @parameterized.expand( + [ + (0, 5, 1), + (1, 5, 2), + (3, 5, 3), + (5, 0, -1), + (5, 1, -2), + (5, 3, -3), + ] + ) + def test_arange(self, start, end, step): + class Arange(nn.Module): + def forward(self, x): + return torch.ops.aten.arange.start_step(start, x.shape[0], step) + + inputs = [torch.randn(end, 1)] + self.run_test( + Arange(), + inputs, + use_dynamo_tracer=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_argmax_aten.py b/tests/py/dynamo/conversion/test_argmax_aten.py index bf469d0901..a3f9f67b95 100644 --- a/tests/py/dynamo/conversion/test_argmax_aten.py +++ b/tests/py/dynamo/conversion/test_argmax_aten.py @@ -11,11 +11,11 @@ class TestArgmaxConverter(DispatchTestCase): [ # input dimension == 1 ("dim_1_keep_dim_true", (3,), 0, True), - ("dim_1_keep_dim_true", (3,), 0, 
False), + ("dim_1_keep_dim_false", (3,), 0, False), # dim == None - ("dim_none", (3,), None, True), - ("dim_none", (3, 3), None, True), - ("dim_none", (3, 3, 3), None, False), + ("dim_1_none_true", (3,), None, True), + ("dim_2_none_true", (3, 3), None, True), + ("dim_3_none_false", (3, 3, 3), None, False), # # common cases ("dim_1_keep_dim_true", (3, 3), 1, True), ("dim_1_keep_dim_false", (3, 3), 1, False), diff --git a/tests/py/dynamo/conversion/test_argmin_aten.py b/tests/py/dynamo/conversion/test_argmin_aten.py new file mode 100644 index 0000000000..f06284f394 --- /dev/null +++ b/tests/py/dynamo/conversion/test_argmin_aten.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestArgminConverter(DispatchTestCase): + @parameterized.expand( + [ + # input dimension == 1 + ("dim_1_keep_dim_true", (3,), 0, True), + ("dim_1_keep_dim_false", (3,), 0, False), + # dim == None + ("dim_1_none_true", (3,), None, True), + ("dim_2_none_true", (3, 3), None, True), + ("dim_3_none_false", (3, 3, 3), None, False), + # # common cases + ("dim_1_keep_dim_true", (3, 3), 1, True), + ("dim_1_keep_dim_false", (3, 3), 1, False), + ("dim_0_keep_dim_true", (4, 4, 4), 0, True), + ("dim_0_keep_dim_false", (4, 4, 4), 0, False), + ("dim_negative_keep_dim_true", (1, 2, 3), -1, True), + ] + ) + def test_argmin(self, _, input_shape, dim, keep_dim): + class ArgMin(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.argmin.default(input, dim, keep_dim) + + input = [torch.randn(*input_shape)] + + self.run_test(ArgMin(), input) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_and_aten.py b/tests/py/dynamo/conversion/test_bitwise_and_aten.py new file mode 100644 index 0000000000..5c2a78a18a --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_and_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseAndConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_and_tensor(self, _, shape): + class bitwise_and(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_and_scalar(self, _, shape, scalar): + class bitwise_and(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_and.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_and_scalar_tensor(self, _, shape, scalar): + class bitwise_and(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_and.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_and(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff 
--git a/tests/py/dynamo/conversion/test_bitwise_not_aten.py b/tests/py/dynamo/conversion/test_bitwise_not_aten.py new file mode 100644 index 0000000000..6dd512ef16 --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_not_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseNotConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_not_tensor(self, _, shape): + class bitwise_not(nn.Module): + def forward(self, val): + return torch.ops.aten.bitwise_not.default(val) + + inputs = [ + torch.randint(0, 2, shape, dtype=torch.bool), + ] + self.run_test( + bitwise_not(), + inputs, + enable_passes=True, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_or_aten.py b/tests/py/dynamo/conversion/test_bitwise_or_aten.py new file mode 100644 index 0000000000..b5e0200734 --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_or_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseOrConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_or_tensor(self, _, shape): + class bitwise_or(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_or_scalar(self, _, shape, scalar): + class bitwise_or(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_or.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_or_scalar_tensor(self, _, shape, scalar): + class bitwise_or(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_or.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_or(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_bitwise_xor_aten.py b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py new file mode 100644 index 0000000000..8c1a8136ef --- /dev/null +++ b/tests/py/dynamo/conversion/test_bitwise_xor_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBitwiseXorConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_bitwise_xor_tensor(self, _, shape): + class bitwise_xor(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + 
bitwise_xor(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_xor_scalar(self, _, shape, scalar): + class bitwise_xor(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_xor.Scalar(tensor, scalar) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ("2d", (5, 3), True), + ("3d", (5, 3, 2), False), + ] + ) + def test_bitwise_xor_scalar_tensor(self, _, shape, scalar): + class bitwise_xor(nn.Module): + def forward(self, tensor): + return torch.ops.aten.bitwise_xor.Scalar_Tensor(scalar, tensor) + + inputs = [ + torch.randint(0, 2, shape, dtype=bool), + ] + self.run_test( + bitwise_xor(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index f17eb7b1d4..c067a0b9ad 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -75,6 +75,21 @@ def forward(self, x): inputs, ) + def test_to_copy_multiple_returns(self): + class ToCopyReturns(nn.Module): + def forward(self, x): + x_1 = x + 1 + y = torch.ops.aten._to_copy.default(x_1, dtype=torch.float) + z = torch.ops.aten._to_copy.default(x_1, dtype=torch.float) + return y, z + + inputs = [torch.rand((1, 3, 10))] + self.run_test( + ToCopyReturns(), + inputs, + precision=torch.float, + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_chunk_aten.py b/tests/py/dynamo/conversion/test_chunk_aten.py new file mode 100644 index 0000000000..1812165b43 --- /dev/null +++ b/tests/py/dynamo/conversion/test_chunk_aten.py @@ -0,0 +1,82 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestChunkConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1,), 3, 0), + ((3,), 3, 0), + ((4,), 3, 0), + ((6,), 3, 0), + ((3,), 1, -1), + ((3,), 3, -1), + ((3,), 4, -1), + ] + ) + def test_chunk_1D(self, shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input = [torch.randn(shape)] + self.run_test( + TestChunk(), + input, + ) + + @parameterized.expand( + [ + ((3, 4), 1, 0), + ((3, 4), 3, 0), + ((3, 4), 4, 0), + ((3, 4), 2, -2), + ((3, 4), 6, -2), + ((3, 4), 3, 1), + ((3, 4), 4, 1), + ((3, 4), 5, -1), + ] + ) + def test_chunk_2D(self, shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input = [torch.randn(shape)] + self.run_test( + TestChunk(), + input, + ) + + @parameterized.expand( + [ + ((3, 4, 2), 1, 0), + ((3, 4, 2), 3, -3), + ((3, 4, 2), 3, 1), + ((3, 4, 2), 4, 1), + ((3, 4, 2), 6, -2), + ((3, 4, 2), 1, 2), + ((3, 4, 2), 3, -1), + ((3, 4, 2), 4, -1), + ] + ) + def test_chunk_3D(self, shape, chunks, dim): + class TestChunk(torch.nn.Module): + def forward(self, input): + out = torch.ops.aten.chunk.default(input, chunks, dim) + return out + + input = [torch.randn(shape)] + self.run_test( + TestChunk(), + input, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_clamp_aten.py b/tests/py/dynamo/conversion/test_clamp_aten.py index fcee7bfa3c..0bad9ee350 100644 --- 
a/tests/py/dynamo/conversion/test_clamp_aten.py +++ b/tests/py/dynamo/conversion/test_clamp_aten.py @@ -49,7 +49,7 @@ def forward(self, x): class TestScalarModule(torch.nn.Module): def forward(self, x): - y = torch.ops.aten.mean.default(x) + y = torch.ops.aten.mean.dim(x, None, True) return torch.ops.aten.clamp.default(y, min, max) input_specs = [ @@ -63,6 +63,30 @@ def forward(self, x): self.run_test_with_dynamic_shape(TestModule(), input_specs) self.run_test_with_dynamic_shape(TestScalarModule(), input_specs) + @parameterized.expand( + [ + param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)), + param("min", min=0.5 * torch.randn(3, 4)), + param("max", max=0.5 * torch.randn(3, 4)), + param( + "minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4) + ), + param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)), + ] + ) + def test_clamp_tensor( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.clamp.Tensor(x, min, max) + + inputs = [torch.randn(3, 4)] + self.run_test(TestModule(), inputs) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/conversion/test_clip_aten.py b/tests/py/dynamo/conversion/test_clip_aten.py index a3819fb4dd..447e2c9e17 100644 --- a/tests/py/dynamo/conversion/test_clip_aten.py +++ b/tests/py/dynamo/conversion/test_clip_aten.py @@ -19,11 +19,38 @@ class TestClipConverter(DispatchTestCase): def test_clip(self, test_name, min=None, max=None): class TestModule(torch.nn.Module): def forward(self, x): - return torch.ops.aten.clamp.default(x, min, max) + return torch.ops.aten.clip.default(x, min, max) inputs = [torch.randn(3, 4)] self.run_test(TestModule(), inputs) + @parameterized.expand( + [ + param( + "defaultInt32", + min=torch.tensor(-1, dtype=torch.int32), + max=torch.tensor(0, dtype=torch.int32), + ), + param( + "defaultFloat32", + min=torch.tensor(0.5, dtype=torch.float32), + max=torch.tensor(1.0, dtype=torch.float32), + ), + param( + "minBiggerThanMax", + min=torch.tensor(1.0, dtype=torch.float32), + max=torch.tensor(0, dtype=torch.int32), + ), + ] + ) + def test_clip_tensor(self, test_name, min=None, max=None): + class TestModule(torch.nn.Module): + def forward(self, x, min, max): + return torch.ops.aten.clip.Tensor(x, min, max) + + inputs = [torch.randn(3, 4), min, max] + self.run_test(TestModule(), inputs) + @parameterized.expand( [ param("default", min=-1, max=0), @@ -37,12 +64,12 @@ def test_clip_with_dynamic_shape_four_dimensions( ): class TestModule(torch.nn.Module): def forward(self, x): - return torch.ops.aten.clamp.default(x, min, max) + return torch.ops.aten.clip.default(x, min, max) class TestScalarModule(torch.nn.Module): def forward(self, x): - y = torch.ops.aten.mean.default(x) - return torch.ops.aten.clamp.default(y, min, max) + y = torch.ops.aten.mean.dim(x, None, True) + return torch.ops.aten.clip.default(y, min, max) input_specs = [ Input( diff --git a/tests/py/dynamo/conversion/test_copy_aten.py b/tests/py/dynamo/conversion/test_copy_aten.py new file mode 100644 index 0000000000..1acb94daf6 --- /dev/null +++ b/tests/py/dynamo/conversion/test_copy_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestCopyConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3,), (3,), False), + ((1, 10), (1, 10), False), + ((2, 3, 4), (2, 3, 4), 
True), + ((2, 3, 4, 5), (2, 3, 4, 5), True), + ] + ) + def test_copy_float(self, input_shape, src_shape, non_blocking): + class Copy(nn.Module): + def forward(self, input, src): + return torch.ops.aten.copy.default(input, src, non_blocking) + + inputs = [torch.randn(input_shape), torch.randn(src_shape)] + self.run_test( + Copy(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py new file mode 100644 index 0000000000..6d7b05f0e1 --- /dev/null +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -0,0 +1,141 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestEmbeddingBagConverter(DispatchTestCase): + @parameterized.expand( + [ + # 1D input + param( + test_name="1d_indices_1", + weight=torch.randn((10, 3), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32), + offsets=torch.tensor([0, 3], dtype=torch.int32), + scale_grad_by_freq=False, + mode=1, + sparse=False, + per_sample_weights=None, + include_last_offset=True, + padding_idx=-1, + ), + param( + test_name="1d_indices_2", + weight=torch.randn((10, 3), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32), + offsets=torch.tensor([0, 5], dtype=torch.int32), + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=torch.randn((6,)), + include_last_offset=False, + padding_idx=-1, + ), + param( + test_name="1d_indices_3", + weight=torch.randn((10, 3), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32), + offsets=torch.tensor([0, 2, 4], dtype=torch.int32), + scale_grad_by_freq=False, + mode=2, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, + ), + # 2D input + # param( + # test_name="2d_indices_1", + # weight=torch.randn((5, 10), dtype=torch.float32), + # indices=torch.tensor([[3, 1], [4, 3]], dtype=torch.int32), + # offsets=torch.tensor([0, 1], dtype=torch.int32), + # scale_grad_by_freq=False, + # mode=0, + # sparse=False, + # per_sample_weights=torch.randn((4,)), + # include_last_offset=False, + # padding_idx=-1, + # ), + # param( + # test_name="2d_indices_3", + # weight=torch.tensor([ + # [0.0, 0.0, 0.0], + # [1.0, 1.0, 1.0], + # [2.0, 2.0, 2.0], + # [3.0, 3.0, 3.0], + # [4.0, 4.0, 4.0], + # [5.0, 5.0, 5.0], + # ], dtype=torch.float32), + # indices=torch.tensor([[0, 2, 1], [3, 5, 4]], dtype=torch.int32), + # offsets=torch.tensor([0, 1], dtype=torch.int32), + # scale_grad_by_freq=False, + # mode=2, + # sparse=False, + # per_sample_weights=None, + # include_last_offset=False, + # padding_idx=-1, + # ), + # param( + # test_name="2d_indices_2", + # weight=torch.randn((5, 5), dtype=torch.float32), + # indices=torch.tensor([[3, 1, 2], [4, 2, 3]], dtype=torch.int32), + # offsets=torch.tensor([0, 2], dtype=torch.int32), + # scale_grad_by_freq=False, + # mode=1, + # sparse=False, + # per_sample_weights=None, + # include_last_offset=False, + # padding_idx=-1, + # ), + # param( + # test_name="2d_indices_2", + # weight=torch.randn((5, 10), dtype=torch.float32), + # indices=torch.tensor([[3, 1, 2, 4], [4, 1, 3, 1]], dtype=torch.int32), + # offsets=torch.tensor([0, 2], dtype=torch.int32), + # scale_grad_by_freq=False, + # mode=0, + # sparse=False, + # per_sample_weights=torch.randn((8,)), + # include_last_offset=True, + # 
padding_idx=-1, + # ), + ] + ) + def test_embedding_bag( + self, + test_name, + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ): + class TestEmbeddingBag(torch.nn.Module): + def forward(self, weight, indices): + return torch.ops.aten._embedding_bag.default( + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + )[0] + + self.run_test( + TestEmbeddingBag(), + inputs=[weight, indices], + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_eq_aten.py b/tests/py/dynamo/conversion/test_eq_aten.py new file mode 100644 index 0000000000..17a372182c --- /dev/null +++ b/tests/py/dynamo/conversion/test_eq_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestEqualConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_eq_tensor(self, _, shape): + class eq(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.eq.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + eq(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_eq_tensor_scalar(self, _, shape, scalar): + class eq(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + eq(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_eq_scalar(self, _, shape, scalar): + class eq(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.eq.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + eq(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_ge_aten.py b/tests/py/dynamo/conversion/test_ge_aten.py new file mode 100644 index 0000000000..6b1ee6d440 --- /dev/null +++ b/tests/py/dynamo/conversion/test_ge_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestGeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_ge_tensor(self, _, shape): + class ge(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ge_tensor_scalar(self, _, shape, scalar): + class ge(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ge.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + 
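+ # Note on the two overloads exercised here (standard ATen semantics, assumed + # rather than defined by this suite): ge.Tensor compares against another tensor + # (a 0-d tensor above), while ge.Scalar takes a Python number directly, e.g. + # torch.ops.aten.ge.Scalar(torch.tensor([1, 2]), 1) # -> tensor([True, True])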
+ @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ge_scalar(self, _, shape, scalar): + class ge(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ge.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ge(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py new file mode 100644 index 0000000000..bf3ad6c8cc --- /dev/null +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -0,0 +1,150 @@ +import pytest +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + +grid_sampler_ops = [ + ( + "input_grid_interpolation_nearest_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_fill", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_clamp", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_reflect", + (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_nearest_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_linear_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_fill_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + 
"input_grid_interpolation_cubic_sample_clamp_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), + ( + "input_grid_interpolation_cubic_sample_reflect_2d", + (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + [1, 1, 5, 5], + [1, 5, 2, 2], + ), +] + + +class TestGridConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1], + grid_sampler_op[2], + grid_sampler_op[3], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid(self, _, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_gt_aten.py b/tests/py/dynamo/conversion/test_gt_aten.py new file mode 100644 index 0000000000..8d9ae24f80 --- /dev/null +++ b/tests/py/dynamo/conversion/test_gt_aten.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestGtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_gt_tensor(self, _, shape): + class gt(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.gt.Tensor(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + gt(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_gt_tensor_scalar(self, _, shape, scalar): + class gt(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.gt.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randn(shape)] + self.run_test( + gt(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_gt_scalar(self, _, shape, scalar): + class gt(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.gt.Scalar(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + gt(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 393eb53c63..88db7f0817 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -27,6 +27,21 @@ def forward(self, x): input, ) + def test_index_zero_two_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test( + TestModule(), + [input, index0], + ) + def test_index_zero_index_three_dim(self): class TestModule(nn.Module): def __init__(self): @@ -44,6 +59,18 @@ def forward(self, x): input, ) + def test_index_zero_index_three_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0, None] + out = 
torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test(TestModule(), [input, index0]) + def test_index_zero_index_one_index_two_three_dim(self): class TestModule(nn.Module): def __init__(self): diff --git a/tests/py/dynamo/conversion/test_le_aten.py b/tests/py/dynamo/conversion/test_le_aten.py new file mode 100644 index 0000000000..373384c6f9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_le_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestLeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_le_tensor(self, _, shape): + class le(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.le.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_le_tensor_scalar(self, _, shape, scalar): + class le(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.le.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_le_scalar(self, _, shape, scalar): + class le(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.le.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + le(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_lt_aten.py b/tests/py/dynamo/conversion/test_lt_aten.py new file mode 100644 index 0000000000..89cb7f42c5 --- /dev/null +++ b/tests/py/dynamo/conversion/test_lt_aten.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestLtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (2, 1)), + ("3d", (2, 1, 2)), + ] + ) + def test_lt_tensor(self, _, shape): + class lt(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.lt.Tensor(lhs_val, rhs_val) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + lt(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_lt_tensor_scalar(self, _, shape, scalar): + class lt(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.lt.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randn(shape)] + self.run_test( + lt(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (2, 1), 1), + ("3d", (2, 1, 2), 2.0), + ] + ) + def test_lt_scalar(self, _, shape, scalar): + class lt(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.lt.Scalar(lhs_val, scalar) + + inputs = [torch.randn(shape)] + self.run_test( + lt(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git 
a/tests/py/dynamo/conversion/test_ne_aten.py b/tests/py/dynamo/conversion/test_ne_aten.py new file mode 100644 index 0000000000..2450ac0945 --- /dev/null +++ b/tests/py/dynamo/conversion/test_ne_aten.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestNotEqualConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d", (5, 3)), + ("3d", (5, 3, 2)), + ] + ) + def test_ne_tensor(self, _, shape): + class ne(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.ne.Tensor(lhs_val, rhs_val) + + inputs = [ + torch.randint(0, 3, shape, dtype=torch.int32), + torch.randint(0, 3, shape, dtype=torch.int32), + ] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ne_tensor_scalar(self, _, shape, scalar): + class ne(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ne.Tensor(lhs_val, torch.tensor(scalar)) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ("2d", (5, 3), 1), + ("3d", (5, 3, 2), 2.0), + ] + ) + def test_ne_scalar(self, _, shape, scalar): + class ne(nn.Module): + def forward(self, lhs_val): + return torch.ops.aten.ne.Scalar(lhs_val, scalar) + + inputs = [torch.randint(0, 3, shape, dtype=torch.int32)] + self.run_test( + ne(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py new file mode 100644 index 0000000000..2803736ad0 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -0,0 +1,241 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestConstantPadConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 2), (1, 1), 0), + ((2, 1), (2, 1), 1), + ((3, 4, 2), (1, 2), 2), + ((3, 4, 2), (1, 2, 3, 1, 2, 3), 0), + ((3, 3, 4, 2), (1, 2, 3, 4), 0), + ((3, 3, 4, 2), (1, 2, 3, 4), 2), + ((3, 3, 4, 2, 1), (1, 2, 3, 4, 5, 1, 2, 3, 4, 5), 0), + ((3, 3, 4, 2, 1, 2), (1, 2, 3, 4, 1, 2, 3, 4), 4), + ] + ) + def test_constant_pad(self, shape, pad, value): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.constant_pad_nd.default(input, pad, value) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + +class TestReflectionPadConverter(DispatchTestCase): + @parameterized.expand( + [ + # Per pytorch doc, the input should be 2D or 3D + ((3, 3), (1, 1)), + ((3, 3), (2, 2)), + ((2, 2, 2), (1, 1)), + ((2, 2, 4), (2, 3)), + ] + ) + def test_reflection_pad1d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.reflection_pad1d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 3D or 4D + ((2, 2, 2), (1, 1, 1, 1)), + ((1, 2, 4), (2, 2, 1, 1)), + ((2, 2, 3, 3), (1, 1, 2, 2)), + ((2, 3, 4, 5), (4, 3, 0, 1)), + ] + ) + def test_reflection_pad2d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return 
torch.ops.aten.reflection_pad2d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 4D or 5D + ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)), + ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)), + ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)), + ] + ) + def test_reflection_pad3d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.reflection_pad3d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + +class TestReplicationPadConverter(DispatchTestCase): + @parameterized.expand( + [ + # Per pytorch doc, the input should be 2D or 3D + ((3, 3), (1, 1)), + ((3, 3), (2, 2)), + ((2, 2, 2), (1, 1)), + ((2, 2, 4), (2, 3)), + ] + ) + def test_replication_pad1d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.replication_pad1d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 3D or 4D + ((2, 2, 2), (1, 1, 1, 1)), + ((1, 2, 4), (2, 2, 1, 1)), + ((2, 2, 3, 3), (1, 1, 2, 2)), + ((2, 3, 4, 5), (4, 3, 0, 1)), + ] + ) + def test_replication_pad2d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.replication_pad2d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 4D or 5D + ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)), + ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)), + ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)), + ] + ) + def test_replication_pad3d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.replication_pad3d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + +class TestCircularPadConverter(DispatchTestCase): + @parameterized.expand( + [ + # Per pytorch doc, the input should be 2D or 3D + ((3, 3), (1, 1)), + ((3, 3), (2, 2)), + ((2, 2, 2), (1, 1)), + ((2, 2, 4), (2, 3)), + ] + ) + def test_circular_pad1d(self, shape, pad): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten._pad_circular.default(input, pad) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 3D or 4D + ((2, 2, 2), (1, 1, 1, 1)), + ((1, 2, 4), (2, 2, 1, 1)), + ((2, 2, 3, 3), (1, 1, 2, 2)), + ((2, 3, 4, 5), (4, 3, 0, 1)), + ] + ) + def test_circular_pad2d(self, shape, pad): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten._pad_circular.default(input, pad) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 4D or 5D + ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)), + ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)), + ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)), + ] + ) + def test_circular_pad3d(self, shape, pad): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten._pad_circular.default(input, pad) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + +class 
TestPadConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3, 3), (2, 2), "constant"), + ((2, 2, 4), (2, 3, 1, 0), "reflect"), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1), "replicate"), + ((2, 3, 4, 5), (3, 2, 1, 0), "circular"), + ] + ) + def test_pad(self, shape, pad, mode, value=None): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.pad.default(input, pad, mode, value) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_sort_aten.py b/tests/py/dynamo/conversion/test_sort_aten.py new file mode 100644 index 0000000000..8bb9bc214e --- /dev/null +++ b/tests/py/dynamo/conversion/test_sort_aten.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestSortConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3, 2, 4), 0, True), + ((2, 3, 4, 5), 1, True), + ((2, 3, 4, 5), 2, False), + ((6, 7, 5, 4, 5), 4, False), + ((1, 5, 2, 1), -1, True), + ((1, 2, 5, 3), -2, False), + ((6, 2, 1, 3), -4, True), + ] + ) + def test_sort(self, input_shape, dim, descending): + class Sort(nn.Module): + def forward(self, x): + return torch.ops.aten.sort.default(x, dim, descending) + + inputs = [torch.randn(*input_shape)] + self.run_test( + Sort(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_tile_aten.py b/tests/py/dynamo/conversion/test_tile_aten.py new file mode 100644 index 0000000000..5a7e98aa7d --- /dev/null +++ b/tests/py/dynamo/conversion/test_tile_aten.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestTileConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3,), (1,)), + ((3,), (0,)), + ((3,), (2,)), + ((2,), (2, 2)), + ((2,), (0, 2)), + ] + ) + def test_tile_1D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + @parameterized.expand( + [ + ((3, 1), (0,)), + ((3, 1), (2,)), + ((2, 3), (2, 2)), + ((2, 3), (1, 0)), + ((2, 3), (0, 2)), + ((2, 3), (4, 2, 3)), + ((2, 3), (0, 0, 3)), + ((2, 3), (4, 2, 3, 1, 2)), + ] + ) + def test_tile_2D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + @parameterized.expand( + [ + ((4, 2, 3), (2,)), + ((4, 2, 3), (1, 2)), + ((1, 2, 3), (2, 3)), + ((1, 2, 3), (2, 3, 4)), + ((1, 2, 3), (2, 3, 4, 5)), + ] + ) + def test_tile_3D(self, shape, dims): + class Tile(nn.Module): + def forward(self, x): + return torch.ops.aten.tile.default(x, dims) + + inputs = [torch.randn(shape)] + self.run_test( + Tile(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_trunc_aten.py b/tests/py/dynamo/conversion/test_trunc_aten.py new file mode 100644 index 0000000000..979ced17e2 --- /dev/null +++ b/tests/py/dynamo/conversion/test_trunc_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness 
import DispatchTestCase + + +class TestTruncConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,),), + ((1, 20),), + ((2, 3, 4),), + ((2, 3, 4, 5),), + ] + ) + def test_trunc_float(self, shape): + class Trunc(nn.Module): + def forward(self, input): + return torch.ops.aten.trunc.default(input) + + inputs = [torch.randn(shape)] + self.run_test( + Trunc(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ((10,),), + ((1, 20),), + ((2, 3, 4),), + ((2, 3, 4, 5),), + ] + ) + def test_trunc_int(self, shape): + class Trunc(nn.Module): + def forward(self, input): + return torch.ops.aten.trunc.default(input) + + inputs = [torch.randint(-10, 10, shape, dtype=torch.int32)] + self.run_test( + Trunc(), + inputs, + enable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/conversion/test_upsample.py b/tests/py/dynamo/conversion/test_upsample.py new file mode 100644 index 0000000000..448b3afb84 --- /dev/null +++ b/tests/py/dynamo/conversion/test_upsample.py @@ -0,0 +1,97 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestUpsampleConverter(DispatchTestCase): + # Nearest upsample driven by output_size (scale_factors is None here) + @parameterized.expand( + [ + ("upsample_nearest2d.vec_outshape_0", (2, 2), (4, 4)), + ("upsample_nearest2d.vec_outshape_1", (2, 2), (5, 5)), + ] + ) + def test_upsample_nearest_output_shape(self, _, input_shape, output_shape): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec(input, output_shape, None) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + # Nearest upsample driven by scale_factors (output_size is None here) + @parameterized.expand( + [ + ("upsample_nearest2d.vec_scale_0", (2, 2), (2, 2)), + ("upsample_nearest2d.vec_scale_1", (2, 2), (1.5, 1.5)), + ] + ) + def test_upsample_nearest_scale_factor(self, _, input_shape, scale_factor): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec(input, None, scale_factor) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + # Bilinear upsample driven by output_size (scale_factors is None here) + @parameterized.expand( + [ + ("upsample_bilinear2d.vec_outshape_0", (2, 2), (4, 4), True), + ("upsample_bilinear2d.vec_outshape_1", (2, 2), (5, 5), True), + ] + ) + def test_upsample_bilinear_output_shape( + self, _, input_shape, output_shape, align_corners + ): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_bilinear2d.vec( + input, + output_shape, + align_corners, + None, + ) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + # Bilinear upsample driven by scale_factors (output_size is None here) + @parameterized.expand( + [ + ("upsample_bilinear2d.vec_scale_0", (2, 2), (2, 2), True), + ("upsample_bilinear2d.vec_scale_1", (2, 2), (1.5, 1.5), True), + ] + ) + def test_upsample_bilinear_scale_factors( + self, _, input_shape, scale_factors, align_corners + ): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return 
torch.ops.aten.upsample_bilinear2d.vec( + input, + None, + align_corners, + scale_factors, + ) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index edbe93eddd..9070c8373f 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -267,6 +267,114 @@ def forward(self, q, k, v): torch._dynamo.reset() +class TestLowerLinear(TestCase): + def test_lower_linear(self): + class Linear(torch.nn.Module): + def forward(self, input, weight, bias): + out = torch.ops.aten.linear.default(input, weight, bias) + return out + + inputs = [ + torch.rand((3, 32)).cuda(), + torch.rand((64, 32)).cuda(), + torch.rand((64,)).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(Linear()) + expected_ops = {torch.ops.aten.linear.default} + unexpected_ops = { + torch.ops.aten.permute.default, + torch.ops.aten.addmm.default, + } + + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg="Linear TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + def test_lower_linear_batch(self): + class Linear(torch.nn.Module): + def forward(self, input, weight, bias): + out = torch.ops.aten.linear.default(input, weight, bias) + return out + + inputs = [ + torch.rand((2, 2, 32)).cuda(), + torch.rand((64, 32)).cuda(), + torch.rand((64,)).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(Linear()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg="Linear TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + class TestLowerViewToReshape(TestCase): def test_view_to_reshape(self): class ViewToReshape(torch.nn.Module): diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 95a8c96e94..84e8d11585 100644 --- 
a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -110,78 +110,6 @@ def forward(self, x): f"The following expected ops were not encountered: {expected_ops_unseen}", ) - def test_lowering_addmm(self): - class AddMM(torch.nn.Module): - def forward(self, x, y, z): - return torch.addmm(x, y, z, beta=16, alpha=5) - - # Operations expected to be included in the traced graph after decompositions - expected_ops = { - torch.ops.aten.add.Tensor, - torch.ops.aten.mul.Tensor, - torch.ops.aten.mm.default, - } - unexpected_ops = {torch.ops.aten.addmm.default} - - inputs = [ - torch.rand( - 1, - 1, - ).cuda(), - torch.rand( - 7, - 8, - ).cuda(), - torch.rand( - 8, - 9, - ).cuda(), - ] - - fx_graph = torch.fx.symbolic_trace(AddMM()) - unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( - fx_graph, - inputs, - expected_ops=expected_ops, - unexpected_ops=unexpected_ops, - min_block_size=1, - ) - - self.assertEquals( - len(unexpected_ops_seen), - 0, - f"The following unexpected ops were encountered: {unexpected_ops_seen}", - ) - - self.assertEquals( - len(expected_ops_unseen), - 0, - f"The following expected ops were not encountered: {expected_ops_unseen}", - ) - - torch._dynamo.reset() - - # Validate that the results between Torch and Torch-TRT are similar - optimized_model = torch_tensorrt.compile( - fx_graph, - "torch_compile", - inputs, - min_block_size=1, - pass_through_build_failures=True, - ) - optimized_model_results = optimized_model(*inputs).detach().cpu() - torch_model_results = fx_graph(*inputs).detach().cpu() - - max_diff = float( - torch.max(torch.abs(optimized_model_results - torch_model_results)) - ) - self.assertAlmostEqual( - max_diff, - 0, - DECIMALS_OF_AGREEMENT, - f"AddMM TRT outputs don't match with the original model.", - ) - def test_lowering_reciprocal(self): class Reciprocal(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 057a95879d..ec88294c48 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -36,6 +36,7 @@ def forward(self, x): opt_shape=(4, 3, 224, 224), max_shape=(8, 3, 224, 224), dtype=torch.float32, + name="x", ) ], "device": torchtrt.Device("cuda:0"), @@ -88,6 +89,7 @@ def forward(self, x): opt_shape=(4, 3, 224, 224), max_shape=(8, 3, 224, 224), dtype=torch.float32, + name="x", ) ], "device": torchtrt.Device("cuda:0"), @@ -95,7 +97,7 @@ def forward(self, x): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "torch_executed_ops": "torch.ops.aten.abs.default", + "torch_executed_ops": {"torch.ops.aten.abs.default"}, "min_block_size": 1, } diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 3ddc965a51..f1c00e751d 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -5,7 +5,6 @@ import torch import torch_tensorrt as torchtrt import torchvision.models as models -from torch._export.serde.serialize import deserialize, serialize from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -44,11 +43,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export( - trt_gm, [input], call_spec=exp_program.call_spec, 
ir="exported_program" - ) - serialized_prog = serialize(trt_exp_program) - deserialized_prog = deserialize(*serialized_prog) + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0]) @@ -57,7 +54,7 @@ def forward(self, x): msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Check Pyt and deserialized TRT exported program outputs - cos_sim = cosine_similarity(model(input), deserialized_prog(input)[0]) + cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", @@ -98,11 +95,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export( - trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program" - ) - serialized_prog = serialize(trt_exp_program) - deserialized_prog = deserialize(*serialized_prog) + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -114,7 +109,7 @@ def forward(self, x): ) # Check Pyt and deserialized TRT exported program outputs - outputs_trt_deser = deserialized_prog(input) + outputs_trt_deser = deser_trt_exp_program(input) for idx in range(len(outputs_pyt)): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) assertions.assertTrue( @@ -157,11 +152,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export( - trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program" - ) - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -213,16 +206,14 @@ def forward(self, x): ], "ir": ir, "min_block_size": 1, - "torch_executed_ops": "torch.ops.aten.relu.default", + "torch_executed_ops": {"torch.ops.aten.relu.default"}, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export( - trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program" - ) - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -262,22 +253,21 @@ def test_resnet18_save_load(ir): 
exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export( - trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program" - ) - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) - cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) outputs_trt_deser = deser_trt_exp_program(input) - cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", @@ -319,8 +309,8 @@ def test_resnet18_save_load(ir): # } # trt_exp_program = torchtrt.compile(model, **compile_spec) -# torch._export.save(trt_exp_program, "/tmp/trt.ep") -# deser_trt_exp_program = torch._export.load("/tmp/trt.ep") +# torch.export.save(trt_exp_program, "/tmp/trt.ep") +# deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # outputs_pyt = model(input) # outputs_trt = trt_exp_program(input) diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl index ccf9d68917..058e8e946c 100644 --- a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl +++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel.tmpl @@ -59,14 +59,14 @@ http_archive( name = "libtorch", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcu121.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip"], ) http_archive( name = "libtorch_pre_cxx11_abi", build_file = "@//third_party/libtorch:BUILD", strip_prefix = "libtorch", - urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-shared-with-deps-2.1.2%2Bcu121.zip"], + urls = ["https://download.pytorch.org/libtorch/test/cu121/libtorch-shared-with-deps-2.2.0%2Bcu121.zip"], ) #################################################################################### diff --git a/version.txt b/version.txt index cc34478c3b..887948350c 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.1.0a0 \ No newline at end of file +2.2.0a0
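For reference, below is a minimal sketch of the torch.export-based serialization round-trip that the updated test_export_serde.py exercises end to end. The model, input shape, spec values, and the /tmp path are illustrative only (not taken from the patch verbatim), and a CUDA device plus a TensorRT-enabled torch_tensorrt build are assumed:

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).cuda()
compile_spec = {
    "inputs": [torchtrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)],
    "ir": "dynamo",
    "min_block_size": 1,
}

# Trace with the dynamo frontend, compile to TRT, then re-wrap the result
# as an ExportedProgram (the same call sequence used in the tests above)
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")

# The public torch.export.save / torch.export.load APIs replace the
# torch._export serialization helpers removed by this patch
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

# Deserialized outputs should match the original model up to COSINE_THRESHOLD
cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0])
assert cos_sim > COSINE_THRESHOLD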