From 51f1569c77bca39ac723cfa37736e7acccc16f03 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Tue, 28 Sep 2021 13:19:19 -0700
Subject: [PATCH] Add checks for structured in-place operations. (#65686)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65686

Fixes: #57827

This PR introduces the `check_inplace` function. It contains the common
checks shared by all structured in-place operators (e.g. dtype, device,
and size checks). The `set_output` method calls `check_inplace` in the
in-place specializations of structured kernels.

Besides that, it also:
- adds overlap assertions for both in-place and out-of-place overloads
- removes in-place operator-specific `TORCH_CHECK`s around the code base

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31234063

Pulled By: ezyang

fbshipit-source-id: fa3b45775af7812e07a282e7cae00b68caf0fdb0
---
 aten/src/ATen/native/Blas.cpp               |  4 ----
 aten/src/ATen/native/LinearAlgebra.cpp      | 10 ----------
 aten/src/ATen/native/ReduceOps.cpp          |  6 ------
 tools/codegen/dest/register_dispatch_key.py | 27 +++++++++++++++++++--
 4 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp
index c18c65784b5fa..9d2933c444dbf 100644
--- a/aten/src/ATen/native/Blas.cpp
+++ b/aten/src/ATen/native/Blas.cpp
@@ -17,10 +17,6 @@ TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec,
     "size mismatch, got ", self.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
   auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
   set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names);
-  auto result = maybe_get_output(0);
-  //this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
-  TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ",
-    "size equal to mat.size(0), yet got output size ", result.sizes(), " and mat.size(0) ", mat.size(0));
 }

 }
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 75fd9e6c61b67..5c3f5bb51002b 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -35,11 +35,6 @@ TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat
   auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self);
   set_output(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
-  const auto& result = maybe_get_output(0);
-  //this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
-  TORCH_CHECK(((result.dim() == 2) && (result.sizes()[0] == mat1.sizes()[0]) && (result.sizes()[1] == mat2.sizes()[1])),
-    "The input tensor must be a matrix with size ", mat1.sizes()[0], "x", mat2.sizes()[1], ", but got a ", result.dim(),
-    "-D tensor with size ", result.sizes()[0], "x", result.sizes()[1]);
 }

 TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) {
@@ -51,11 +46,6 @@ TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) {
   auto names = at::namedinference::compute_matmul_outnames(self, mat2);
   set_output(0, {self.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
-  const auto& result = maybe_get_output(0);
-  //this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
-  TORCH_CHECK(((result.dim() == 2) && (result.sizes()[0] == self.sizes()[0]) && (result.sizes()[1] == mat2.sizes()[1])),
-    "The input tensor must be a matrix with size ", self.sizes()[0], "x", mat2.sizes()[1], ", but got a ", result.dim(),
-    "-D tensor with size ", result.sizes()[0], "x", result.sizes()[1]);
 }

 template
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 3674a35dfbfea..a96478ba91e9d 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -179,12 +179,6 @@ void meta_func_cum_ops(
   if (result.defined()) {
     out_dtype = dtype.value_or(result.scalar_type());
-    // This check is still here because the inplace version of structured kernels
-    // does not do any checks on 'set_output'.
-    TORCH_CHECK(
-        out_dtype == result.scalar_type(),
-        name, "(): provided dtype must match dtype of result tensor. Got: ",
-        toString(out_dtype), ". Expected: ", toString(result.scalar_type()));
   } else {
     auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
     out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py
index ec3a2e6afc0b1..7b828dc4657d4 100644
--- a/tools/codegen/dest/register_dispatch_key.py
+++ b/tools/codegen/dest/register_dispatch_key.py
@@ -88,11 +88,32 @@ def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
 }
 """]

+def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
+    return ["""
+void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
+  // These checks are needed on those operators that:
+  // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
+  // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
+  // For other operators (e.g. 'add'), 'TensorIterator' already checks
+  // these things separately.
+  TORCH_CHECK(options.dtype() == self.dtype(),
+      "Bad in-place call: ",
+      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
+  TORCH_CHECK(options.device() == self.device(),
+      "Bad in-place call: ",
+      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
+  TORCH_CHECK(sizes == self.sizes(),
+      "Bad in-place call: ",
+      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
+}
+"""]
+
 def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
     return [
         *gen_create_out_helper(backend_index),
-        *gen_resize_out_helper(backend_index)
+        *gen_resize_out_helper(backend_index),
+        *gen_check_inplace_helper(backend_index)
     ]
@@ -423,7 +444,9 @@ def gen_class_set_output_body(self, k: SchemaKind) -> str:
             return f"""{maybe_set_guard_line}
 outputs_[output_idx] = create_out(sizes, strides, options);"""
         elif k is SchemaKind.inplace:
-            return maybe_set_guard
+            return f"""{maybe_set_guard_line}
+const auto& out = outputs_[output_idx].get();
+check_inplace(out, sizes, options);"""
         elif k is SchemaKind.out:
             return f"""{maybe_set_guard_line}
 const auto& out = outputs_[output_idx].get();
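
As a sketch of the user-visible effect of the dtype branch: the explicit
dtype check removed from meta_func_cum_ops above is now performed by
check_inplace when set_output runs for the in-place overload. The exact
exception text below is an assumption pieced together from the message
strings in the generated helper:

    import torch

    a = torch.rand(3)  # float32 tensor
    # In-place cumsum with an explicit dtype that differs from `self`.
    # The mismatch is now reported by `check_inplace`, called from the
    # in-place specialization of `set_output`, roughly as:
    #   RuntimeError: Bad in-place call: input tensor dtype float and
    #   output tensor dtype double should match
    a.cumsum_(0, dtype=torch.float64)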
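
Similarly, the size branch takes over from the per-operator shape
TORCH_CHECKs removed from addmv, addmm, and mm above. Another sketch
under the same assumptions about the error text:

    import torch

    a = torch.rand(2, 2)
    m1 = torch.rand(2, 3)
    m2 = torch.rand(3, 4)
    # mat1 @ mat2 implies a 2x4 output, but `self` is 2x2, so the size
    # check in `check_inplace` should reject this call, roughly as:
    #   RuntimeError: Bad in-place call: input tensor size [2, 2] and
    #   output tensor size [2, 4] should match
    a.addmm_(m1, m2)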