From 9ae63bd87c21a59094b3f8344bac8910990d395b Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Thu, 30 Sep 2021 11:31:20 -0700
Subject: [PATCH] Revert D31238123: [pytorch][PR] Avoid saving self for
 `softmax` and `log_softmax`

Test Plan: revert-hammer

Differential Revision: D31238123 (https://github.com/pytorch/pytorch/commit/fb412bdd80a04e648492dd26c3400a87e8490198)

Original commit changeset: afd319d3676d

fbshipit-source-id: b7980d653a4b8322a225f1dd08c2857ecbe5bc94
---
 aten/src/ATen/native/SoftMax.cpp            | 16 ++++++++--------
 aten/src/ATen/native/cuda/SoftMax.cu        | 12 ++++++------
 aten/src/ATen/native/native_functions.yaml  |  8 ++++----
 aten/src/ATen/native/sparse/SoftMax.cpp     | 15 ++++++++-------
 aten/src/ATen/native/sparse/cuda/SoftMax.cu | 17 +++++++++--------
 .../check_backward_compatibility.py         |  2 --
 tools/autograd/derivatives.yaml             | 14 +++++++-------
 torch/csrc/autograd/FunctionsManual.cpp     | 22 ++++++++++++++++++++--
 torch/csrc/autograd/FunctionsManual.h       |  1 +
 torch/csrc/jit/runtime/symbolic_script.cpp  |  4 ++--
 10 files changed, 65 insertions(+), 46 deletions(-)

diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp
index b6efe294052f80..b16b2ab3a12770 100644
--- a/aten/src/ATen/native/SoftMax.cpp
+++ b/aten/src/ATen/native/SoftMax.cpp
@@ -56,7 +56,7 @@ TORCH_META_FUNC(_softmax_backward_data)
 (const Tensor& grad,
  const Tensor& output,
  int64_t dim,
- ScalarType input_dtype) {
+ const Tensor& input) {
   TensorArg grad_arg{grad, "grad", 1}, output_arg{output, "output", 2};
   checkSameSize("softmax_backward", grad_arg, output_arg);
 
@@ -65,7 +65,7 @@ TORCH_META_FUNC(_softmax_backward_data)
   auto grad_input_options =
       grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 
-  bool half_to_float = grad.scalar_type() != input_dtype;
+  bool half_to_float = grad.scalar_type() != input.scalar_type();
   if (half_to_float) {
     // The code below is only valid for the CUDA implementation. It's "okay"
     // to put it here because half-to-float conversion is not supported by
@@ -75,7 +75,7 @@ TORCH_META_FUNC(_softmax_backward_data)
     // implementation of this kernel and it is not true that the grad type is
     // float and the input dtype is half (see #63057).
     if (grad.scalar_type() == ScalarType::Float &&
-        input_dtype == ScalarType::Half) {
+        input.scalar_type() == ScalarType::Half) {
       grad_input_options = grad_input_options.dtype(ScalarType::Half);
     }
   }
@@ -92,12 +92,12 @@ TORCH_META_FUNC(_log_softmax_backward_data)
 (const Tensor& grad,
  const Tensor& output,
  int64_t dim,
- ScalarType input_dtype){
+ const Tensor& input){
   int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
 
   TensorOptions grad_input_options(
       grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));
-  bool half_to_float = grad.scalar_type() != input_dtype;
+  bool half_to_float = grad.scalar_type() != input.scalar_type();
   if (half_to_float) {
     // The code below is only valid for the CUDA implementation. It's "okay"
     // to put it here because half-to-float conversion is not supported by
@@ -107,7 +107,7 @@ TORCH_META_FUNC(_log_softmax_backward_data)
     // implementation of this kernel and it is not true that the grad type is
     // float and the input dtype is half (see #63057).
     if (grad.scalar_type() == ScalarType::Float &&
-        input_dtype == ScalarType::Half) {
+        input.scalar_type() == ScalarType::Half) {
       grad_input_options = grad_input_options.dtype(ScalarType::Half);
     }
   }
@@ -292,7 +292,7 @@ TORCH_IMPL_FUNC(softmax_backward_cpu_out)
 (const Tensor& grad,
  const Tensor& output,
  int64_t dim,
- ScalarType input_dtype,
+ const Tensor& input,
  const Tensor& grad_input) {
   int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
   auto grad_ = grad.contiguous();
@@ -324,7 +324,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
     const Tensor& grad,
     const Tensor& output,
     int64_t dim,
-    ScalarType input_dtype,
+    const Tensor& input,
     const Tensor& grad_input) {
   int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
   auto grad_ = grad.contiguous();
diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 09530cb38a6adf..8fd1c530ba57f8 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -906,13 +906,13 @@ TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
   const Tensor& grad,
   const Tensor& output,
   int64_t dim,
-  ScalarType input_dtype,
+  const Tensor& input,
   const Tensor& grad_input) {
-  bool half_to_float = grad.scalar_type() != input_dtype;
+  bool half_to_float = grad.scalar_type() != input.scalar_type();
   if (half_to_float) {
     TORCH_CHECK(
         (grad.scalar_type() == ScalarType::Float &&
-         input_dtype == ScalarType::Half),
+         input.scalar_type() == ScalarType::Half),
         "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
   }
   host_softmax_backward(grad, output, dim, half_to_float, grad_input);
@@ -930,13 +930,13 @@ TORCH_IMPL_FUNC(softmax_backward_cuda_out)
 (const Tensor& grad,
  const Tensor& output,
  int64_t dim,
- ScalarType input_dtype,
+ const Tensor& input,
  const Tensor& grad_input) {
-  bool half_to_float = grad.scalar_type() != input_dtype;
+  bool half_to_float = grad.scalar_type() != input.scalar_type();
   if (half_to_float) {
     TORCH_CHECK(
         (grad.scalar_type() == ScalarType::Float &&
-         input_dtype == ScalarType::Half),
+         input.scalar_type() == ScalarType::Half),
         "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
   }
   Tensor tmp = grad * output;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 72bce02de4e811..0595665df81181 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2653,10 +2653,10 @@
     CPU: log_softmax_cpu_out
     CUDA: log_softmax_cuda_out
 
-- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
   structured_delegate: _log_softmax_backward_data.out
 
-- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: log_softmax_backward_cpu_out
@@ -3910,10 +3910,10 @@
     CPU: softmax_cpu_out
     CUDA: softmax_cuda_out
 
-- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
   structured_delegate: _softmax_backward_data.out
 
-- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)
+- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
   dispatch:
     CPU: softmax_backward_cpu_out
diff --git a/aten/src/ATen/native/sparse/SoftMax.cpp b/aten/src/ATen/native/sparse/SoftMax.cpp
index 29662b00a1837d..53a902d1fe3d6c 100644
--- a/aten/src/ATen/native/sparse/SoftMax.cpp
+++ b/aten/src/ATen/native/sparse/SoftMax.cpp
@@ -372,7 +372,7 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t di
 }
 
 template
-void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& grad, const Tensor& output, const int64_t dim, ScalarType input_dtype) {
+void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& grad, const Tensor& output, const int64_t dim) {
   /*
     If LogSoftMax == false, then
 
@@ -411,13 +411,14 @@ void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& gra
   auto grad_offsets = get_offsets(grad_indices, sizes, -1);
 
   if (dim >= sparse_dim) {
+    Tensor unused;
     if (out_offsets == grad_offsets) {
       if (LogSoftMax) {
         auto r = at::cpu::_log_softmax_backward_data(
-            grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+            grad_values, out_values, dim - sparse_dim + 1, unused);
         values.set_(r);
       } else {
-        auto r = at::cpu::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+        auto r = at::cpu::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, unused);
         values.set_(r);
       }
     } else {
@@ -427,10 +428,10 @@ void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& gra
         if (j < grad_nnz && out_offsets[i] == grad_offsets[j]) {
           if (LogSoftMax) {
             auto r = at::cpu::_log_softmax_backward_data(
-                grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+                grad_values[j], out_values[i], dim - sparse_dim, unused);
             values[i].copy_(r);
           } else {
-            auto r = at::cpu::_softmax_backward_data(grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+            auto r = at::cpu::_softmax_backward_data(grad_values[j], out_values[i], dim - sparse_dim, unused);
             values[i].copy_(r);
           }
         }
@@ -561,7 +562,7 @@ Tensor softmax_backward_sparse_cpu(
   }
   AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
     cpu_sparse_coo_softmax_backward(
-        grad_input, grad, output, dim_, input_.scalar_type());
+        grad_input, grad, output, dim_);
   });
   return grad_input;
 }
@@ -580,7 +581,7 @@ Tensor log_softmax_backward_sparse_cpu(
   }
   AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
     cpu_sparse_coo_softmax_backward(
-        grad_input, grad, output, dim_, input_.scalar_type());
+        grad_input, grad, output, dim_);
   });
   return grad_input;
 }
diff --git a/aten/src/ATen/native/sparse/cuda/SoftMax.cu b/aten/src/ATen/native/sparse/cuda/SoftMax.cu
index 42561155a8a897..f5e4d98050503b 100644
--- a/aten/src/ATen/native/sparse/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/sparse/cuda/SoftMax.cu
@@ -431,8 +431,7 @@ void cuda_sparse_coo_softmax_backward(
     Tensor& grad_input,
     const Tensor& grad,
     const Tensor& output,
-    const int64_t dim,
-    ScalarType input_dtype) {
+    const int64_t dim) {
   /*
     See ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax_backward for the
     CPU implementation of the sparse softmax backward algorithm that this
@@ -464,12 +463,13 @@ void cuda_sparse_coo_softmax_backward(
   /* when dim >= sparse_dim the dense backward is used */
   if (dim >= sparse_dim) {
     if (at::native::cuda_equal(out_offsets, grad_offsets) == true) {
+      Tensor unused = at::native::empty_like(grad_values);
       if (LogSoftMax) {
         auto r = at::cuda::_log_softmax_backward_data(
-            grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+            grad_values, out_values, dim - sparse_dim + 1, unused);
         values.set_(r);
       } else {
-        auto r = at::cuda::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+        auto r = at::cuda::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, unused);
         values.set_(r);
       }
     } else {
@@ -480,6 +480,7 @@ void cuda_sparse_coo_softmax_backward(
       auto out_offsets_accessor = host_out_offsets.data_ptr();
       auto grad_offsets_accessor = host_grad_offsets.data_ptr();
       for (int64_t i = 0; i < out_nnz; i++) {
+        Tensor unused = at::native::empty_like(grad_values);
         auto low = thrust::lower_bound(
             grad_offsets_accessor,
             grad_offsets_accessor + grad_offsets.size(0),
@@ -492,11 +493,11 @@ void cuda_sparse_coo_softmax_backward(
         if (j < grad_nnz && out_offsets_accessor[i] == grad_offsets_accessor[j]) {
           if (LogSoftMax) {
             auto r = at::cuda::_log_softmax_backward_data(
-                grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+                grad_values[j], out_values[i], dim - sparse_dim, unused);
             values[i].copy_(r);
           } else {
             auto r = at::cuda::_softmax_backward_data(
-                grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+                grad_values[j], out_values[i], dim - sparse_dim, unused);
             values[i].copy_(r);
           }
         }
@@ -608,7 +609,7 @@ Tensor softmax_backward_sparse_cuda(
   }
   AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
     cuda_sparse_coo_softmax_backward(
-        grad_input, grad, output, dim_, input_.scalar_type());
+        grad_input, grad, output, dim_);
   });
   return grad_input;
 }
@@ -628,7 +629,7 @@ Tensor log_softmax_backward_sparse_cuda(
 
   AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
     cuda_sparse_coo_softmax_backward(
-        grad_input, grad, output, dim_, input_.scalar_type());
+        grad_input, grad, output, dim_);
   });
   return grad_input;
 }
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index a79e031e9b61e9..f884ff7622a292 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -52,8 +52,6 @@
     ("aten::randperm", datetime.date(9999, 1, 1)),
     ("aten::thnn_conv2d_forward", datetime.date(2021, 9, 30)),
     ("aten::thnn_conv2d_backward", datetime.date(2021, 9, 30)),
-    ("aten::_log_softmax_backward_data", datetime.date(2021, 10, 21)),
-    ("aten::_softmax_backward_data", datetime.date(2021, 10, 21))
 ]
 
 ALLOW_LIST_COMPILED = [
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 143bd379bad58b..14823e09a65f27 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1683,7 +1683,7 @@
   self: log_sigmoid_backward(grad, self, buffer)
 
 - name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
-  self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
+  self: _log_softmax_backward_data(grad, result, dim, self)
 
 - name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
   self: _sparse_log_softmax_backward_data(grad, result, dim, self)
@@ -1701,7 +1701,7 @@
   self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
 
 - name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
-  self: _softmax_backward_data(grad, result, dim, self.scalar_type())
+  self: _softmax_backward_data(grad, result, dim, self)
 
 - name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
   self: _sparse_softmax_backward_data(grad, result, dim, self)
@@ -1948,9 +1948,9 @@
   grad_output: log_sigmoid_backward(grad, self, buffer)
   self: log_sigmoid_double_backward(grad * grad_output, self)
 
-- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
   grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
-  output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype())
+  self: log_softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(self.dtype())
 
 - name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
   # self_is_result is always false here since double backward call is an out-of-place call, self is input itself
@@ -2030,9 +2030,9 @@
   grad_output: softplus_backward(grad, self, beta, threshold, output)
   self: softplus_double_backward(grad * grad_output, self, beta, threshold)
 
-- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
-  grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype)
-  output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype())
+- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
+  grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, self)
+  self: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(self.dtype())
 
 - name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
   grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index bb817b486cd80c..7232f0b741f14c 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1234,8 +1234,26 @@ Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
   return grad * (z - 1) * z;
 }
 
-Tensor softmax_double_backward(const Tensor& grad, const Tensor& grad_output, int dim, const Tensor& output) {
-  return grad_output * grad - (output * grad_output).sum(dim, true) * grad - grad_output * (output * grad).sum(dim, true);
+Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
+  const auto& gO = grad_output;
+  const auto& ggI = grad;
+
+  auto ggI_output = ggI * output;
+  auto ggI_out_sum = ggI_output.sum(dim, true);
+  auto ggI_out_sum_output = ggI_out_sum * output;
+  auto gO_out_sum = (gO * output).sum(dim, true);
+
+  // gI calculation
+  auto gI_t0 = ggI_output * (gO - gO_out_sum);
+  auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum));
+  auto gI_t2 = ggI_out_sum_output * gO;
+  auto gI_t3 = ggI_out_sum_output * gO_out_sum;
+  return gI_t0 - gI_t1 - gI_t2 + gI_t3;
+}
+
+Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
+  auto z = output.exp();
+  return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad);
 }
 
 // NOTE: [How to write vmap-compatible backward formulas]
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index eba38770e35163..2e1a36f1b6db99 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -122,6 +122,7 @@ Tensor binary_cross_entropy_target_backward(
 at::Tensor binary_cross_entropy_with_logits_target_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction);
 at::Tensor log_sigmoid_double_backward(const at::Tensor & grad, const at::Tensor & input);
 at::Tensor softmax_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, int dim, const at::Tensor & output);
+at::Tensor log_softmax_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, int dim, const at::Tensor & output);
 at::Tensor binary_cross_entropy_double_backward(const at::Tensor & grad_output, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional& weight, int64_t reduction);
 at::Tensor binary_cross_entropy_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional& weight, int64_t reduction);
 at::Tensor l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction);
diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp
index 3645864226e436..2cf56ff111d482 100644
--- a/torch/csrc/jit/runtime/symbolic_script.cpp
+++ b/torch/csrc/jit/runtime/symbolic_script.cpp
@@ -1208,7 +1208,7 @@ const std::vector functions = {
         def log_softmax(self, dim: int, dtype: Optional[int]):
             result = torch.log_softmax(self, dim, dtype)
             def backward(grad_output):
-                grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self.dtype)
+                grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self)
                 return grad_self, None, None
 
             return result, backward
@@ -1222,7 +1222,7 @@ const std::vector functions = {
         def softmax(self, dim: int, dtype: Optional[int]):
             result = torch.softmax(self, dim, dtype)
             def backward(grad_output):
-                grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
+                grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
                 return grad_self, None, None
 
             return result, backward