Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
Add checks for structured in-place operations. (pytorch#65686)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#65686

Fixes: pytorch#57827

This PR introduces a `check_inplace` function. It contains the common checks shared by all
structured in-place operators (e.g. dtype, device, and size checks). The `set_output` method
calls `check_inplace` in the in-place specializations of structured kernels.

Besides that, it also:
- adds overlap assertions for both in-place and out-of-place overloads
- removes the in-place-specific `TORCH_CHECK` calls scattered around the code base

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31234063

Pulled By: ezyang

fbshipit-source-id: fa3b45775af7812e07a282e7cae00b68caf0fdb0
  • Loading branch information
ysiraichi authored and facebook-github-bot committed Sep 28, 2021
1 parent 93852bb commit 51f1569
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
4 changes: 0 additions & 4 deletions aten/src/ATen/native/Blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec,
"size mismatch, got ", self.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names);
auto result = maybe_get_output(0);
//this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ",
"size equal to mat.size(0), yet got output size ", result.sizes(), " and mat.size(0) ", mat.size(0));
}
}

Expand Down
10 changes: 0 additions & 10 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat

auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self);
set_output(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
const auto& result = maybe_get_output(0);
//this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
TORCH_CHECK(((result.dim() == 2) && (result.sizes()[0] == mat1.sizes()[0]) && (result.sizes()[1] == mat2.sizes()[1])),
"The input tensor must be a matrix with size ", mat1.sizes()[0], "x", mat2.sizes()[1], ", but got a ", result.dim(),
"-D tensor with size ", result.sizes()[0], "x", result.sizes()[1]);
}

TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) {
Expand All @@ -51,11 +46,6 @@ TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) {

auto names = at::namedinference::compute_matmul_outnames(self, mat2);
set_output(0, {self.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
const auto& result = maybe_get_output(0);
//this check can fire for inplace op only, for all other versions result is guaranteed to be correct size
TORCH_CHECK(((result.dim() == 2) && (result.sizes()[0] == self.sizes()[0]) && (result.sizes()[1] == mat2.sizes()[1])),
"The input tensor must be a matrix with size ", self.sizes()[0], "x", mat2.sizes()[1], ", but got a ", result.dim(),
"-D tensor with size ", result.sizes()[0], "x", result.sizes()[1]);
}

template <typename Meta>
Expand Down
6 changes: 0 additions & 6 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,6 @@ void meta_func_cum_ops(

if (result.defined()) {
out_dtype = dtype.value_or(result.scalar_type());
// This check is still here because the inplace version of structured kernels
// does not do any checks on 'set_output'.
TORCH_CHECK(
out_dtype == result.scalar_type(),
name, "(): provided dtype must match dtype of result tensor. Got: ",
toString(out_dtype), ". Expected: ", toString(result.scalar_type()));
} else {
auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
Expand Down
27 changes: 25 additions & 2 deletions tools/codegen/dest/register_dispatch_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,32 @@ def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
}
"""]

def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
    """Emit the C++ `check_inplace` helper for generated dispatch-key registrations.

    Returns a one-element list holding a verbatim C++ snippet, mirroring the
    shape of the sibling `gen_create_out_helper`/`gen_resize_out_helper`
    functions.  `backend_index` is accepted only for signature parity with
    those siblings and is not consulted here.
    """
    # NOTE: this snippet is spliced verbatim into generated C++ source —
    # keep its text stable.
    helper_source = """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  // These checks are needed on those operators that:
  //   1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  //   2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  // For other operators (e.g. 'add'), 'TensorIterator' already checks
  // these things separately.
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
    return [helper_source]


def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
    """Collect every C++ helper snippet emitted at the top of a generated
    dispatch-key registration file.

    Fix: the source text carried a stale duplicate of the
    `gen_resize_out_helper` splice line (a diff artifact retaining both the
    pre-change line without a trailing comma and its replacement), which is
    not valid Python.  Each helper is now spliced in exactly once, in order:
    create_out, resize_out, check_inplace.
    """
    return [
        *gen_create_out_helper(backend_index),
        *gen_resize_out_helper(backend_index),
        *gen_check_inplace_helper(backend_index),
    ]


Expand Down Expand Up @@ -423,7 +444,9 @@ def gen_class_set_output_body(self, k: SchemaKind) -> str:
return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
elif k is SchemaKind.inplace:
return maybe_set_guard
return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);"""
elif k is SchemaKind.out:
return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
Expand Down

0 comments on commit 51f1569

Please sign in to comment.