
Revert D31238123: [pytorch][PR] Avoid saving self for `softmax` and `log_softmax`

Test Plan: revert-hammer

Differential Revision: D31238123 (pytorch@fb412bd)

Original commit changeset: afd319d3676d

fbshipit-source-id: b7980d653a4b8322a225f1dd08c2857ecbe5bc94
suo authored and facebook-github-bot committed Sep 30, 2021
1 parent 541eb1d commit 9ae63bd
Showing 10 changed files with 65 additions and 46 deletions.
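For orientation, a minimal Python sketch (not part of this diff) of what "saving self" refers to: with the schema restored by this revert, the softmax/log_softmax backward ops receive the forward input tensor itself, whereas the reverted change passed only its dtype so autograd did not need to keep the input alive for the backward pass.

import torch

x = torch.randn(4, 8, requires_grad=True)
y = torch.log_softmax(x, dim=-1)
# With this revert, _log_softmax_backward_data takes the input tensor itself;
# the change being reverted passed only x's ScalarType instead.
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 8])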
16 changes: 8 additions & 8 deletions aten/src/ATen/native/SoftMax.cpp
@@ -56,7 +56,7 @@ TORCH_META_FUNC(_softmax_backward_data)
(const Tensor& grad,
const Tensor& output,
int64_t dim,
-ScalarType input_dtype) {
+const Tensor& input) {
TensorArg grad_arg{grad, "grad", 1}, output_arg{output, "output", 2};
checkSameSize("softmax_backward", grad_arg, output_arg);

@@ -65,7 +65,7 @@ TORCH_META_FUNC(_softmax_backward_data)
auto grad_input_options =
grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);

-bool half_to_float = grad.scalar_type() != input_dtype;
+bool half_to_float = grad.scalar_type() != input.scalar_type();
if (half_to_float) {
// The code below is only valid for the CUDA implementation. It's "okay"
// to put it here because half-to-float conversion is not supported by
@@ -75,7 +75,7 @@ TORCH_META_FUNC(_softmax_backward_data)
// implementation of this kernel and it is not true that the grad type is
// float and the input dtype is half (see #63057).
if (grad.scalar_type() == ScalarType::Float &&
-input_dtype == ScalarType::Half) {
+input.scalar_type() == ScalarType::Half) {
grad_input_options = grad_input_options.dtype(ScalarType::Half);
}
}
@@ -92,12 +92,12 @@ TORCH_META_FUNC(_log_softmax_backward_data)
(const Tensor& grad,
const Tensor& output,
int64_t dim,
-ScalarType input_dtype){
+const Tensor& input){
int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
TensorOptions grad_input_options(
grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));

-bool half_to_float = grad.scalar_type() != input_dtype;
+bool half_to_float = grad.scalar_type() != input.scalar_type();
if (half_to_float) {
// The code below is only valid for the CUDA implementation. It's "okay"
// to put it here because half-to-float conversion is not supported by
@@ -107,7 +107,7 @@ TORCH_META_FUNC(_log_softmax_backward_data)
// implementation of this kernel and it is not true that the grad type is
// float and the input dtype is half (see #63057).
if (grad.scalar_type() == ScalarType::Float &&
-input_dtype == ScalarType::Half) {
+input.scalar_type() == ScalarType::Half) {
grad_input_options = grad_input_options.dtype(ScalarType::Half);
}
}
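The half_to_float branches above only matter on CUDA, where softmax/log_softmax can promote a half input to a float output; the backward must then hand back a half gradient for the input. A small sketch of that dtype behavior (assumes a CUDA device is available; not part of this diff):

import torch

if torch.cuda.is_available():
    x = torch.randn(3, 5, device="cuda", dtype=torch.float16, requires_grad=True)
    y = torch.softmax(x, dim=-1, dtype=torch.float32)  # half -> float promotion
    y.sum().backward()
    # the output is float32, but the gradient comes back in the input's dtype
    print(y.dtype, x.grad.dtype)  # torch.float32 torch.float16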
@@ -292,7 +292,7 @@ TORCH_IMPL_FUNC(softmax_backward_cpu_out)
(const Tensor& grad,
const Tensor& output,
int64_t dim,
-ScalarType input_dtype,
+const Tensor& input,
const Tensor& grad_input) {
int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
auto grad_ = grad.contiguous();
@@ -324,7 +324,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
const Tensor& grad,
const Tensor& output,
int64_t dim,
-ScalarType input_dtype,
+const Tensor& input,
const Tensor& grad_input) {
int64_t dim_ = maybe_wrap_dim(dim, grad.dim());
auto grad_ = grad.contiguous();
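The CPU kernels above implement the standard softmax backward. As a quick numeric sketch (not part of this diff), the identity they compute can be checked against autograd:

import torch

x = torch.randn(2, 6, dtype=torch.double, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.randn_like(y)

# softmax backward identity: dx = y * (dy - (dy * y).sum(dim, keepdim=True))
manual = y * (dy - (dy * y).sum(-1, keepdim=True))
auto, = torch.autograd.grad(y, x, dy)
print(torch.allclose(manual, auto))  # True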
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -906,13 +906,13 @@ TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
const Tensor& grad,
const Tensor& output,
int64_t dim,
-ScalarType input_dtype,
+const Tensor& input,
const Tensor& grad_input) {
-bool half_to_float = grad.scalar_type() != input_dtype;
+bool half_to_float = grad.scalar_type() != input.scalar_type();
if (half_to_float) {
TORCH_CHECK(
(grad.scalar_type() == ScalarType::Float &&
-input_dtype == ScalarType::Half),
+input.scalar_type() == ScalarType::Half),
"expected input and grad types to match, or input to be at::Half and grad to be at::Float");
}
host_softmax_backward<LogSoftMaxBackwardEpilogue,true>(grad, output, dim, half_to_float, grad_input);
@@ -930,13 +930,13 @@ TORCH_IMPL_FUNC(softmax_backward_cuda_out)
(const Tensor& grad,
const Tensor& output,
int64_t dim,
-ScalarType input_dtype,
+const Tensor& input,
const Tensor& grad_input) {
-bool half_to_float = grad.scalar_type() != input_dtype;
+bool half_to_float = grad.scalar_type() != input.scalar_type();
if (half_to_float) {
TORCH_CHECK(
(grad.scalar_type() == ScalarType::Float &&
-input_dtype == ScalarType::Half),
+input.scalar_type() == ScalarType::Half),
"expected input and grad types to match, or input to be at::Half and grad to be at::Float");
}
Tensor tmp = grad * output;
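Likewise for the CUDA log_softmax backward: the gradient identity it implements can be checked numerically (sketch, runs on CPU as well; not part of this diff):

import torch

x = torch.randn(2, 6, dtype=torch.double, requires_grad=True)
y = torch.log_softmax(x, dim=-1)
dy = torch.randn_like(y)

# log_softmax backward identity: dx = dy - exp(y) * dy.sum(dim, keepdim=True)
manual = dy - y.exp() * dy.sum(-1, keepdim=True)
auto, = torch.autograd.grad(y, x, dy)
print(torch.allclose(manual, auto))  # True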
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -2653,10 +2653,10 @@
CPU: log_softmax_cpu_out
CUDA: log_softmax_cuda_out

-- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
structured_delegate: _log_softmax_backward_data.out

-- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU: log_softmax_backward_cpu_out
@@ -3910,10 +3910,10 @@
CPU: softmax_cpu_out
CUDA: softmax_cuda_out

-- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
structured_delegate: _softmax_backward_data.out

-- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)
+- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
structured: True
dispatch:
CPU: softmax_backward_cpu_out
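With the schema restored above, the backward ops take the forward input tensor as their last argument. A hedged sketch of calling the internal op directly on a build at this commit (on builds where the dtype-based schema is in effect, the last argument is a ScalarType instead):

import torch

x = torch.randn(4, 4, requires_grad=True)
out = torch.softmax(x, dim=1)
grad_out = torch.ones_like(out)

# _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
grad_in = torch._softmax_backward_data(grad_out, out, 1, x)
print(grad_in.shape)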
15 changes: 8 additions & 7 deletions aten/src/ATen/native/sparse/SoftMax.cpp
@@ -372,7 +372,7 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t di
}

template <typename scalar_t, bool LogSoftMax>
-void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& grad, const Tensor& output, const int64_t dim, ScalarType input_dtype) {
+void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& grad, const Tensor& output, const int64_t dim) {
/*
If LogSoftMax == false, then
@@ -411,13 +411,14 @@ void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& gra
auto grad_offsets = get_offsets(grad_indices, sizes, -1);

if (dim >= sparse_dim) {
+Tensor unused;
if (out_offsets == grad_offsets) {
if (LogSoftMax) {
auto r = at::cpu::_log_softmax_backward_data(
-grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+grad_values, out_values, dim - sparse_dim + 1, unused);
values.set_(r);
} else {
-auto r = at::cpu::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+auto r = at::cpu::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, unused);
values.set_(r);
}
} else {
@@ -427,10 +428,10 @@ void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& gra
if (j < grad_nnz && out_offsets[i] == grad_offsets[j]) {
if (LogSoftMax) {
auto r = at::cpu::_log_softmax_backward_data(
-grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+grad_values[j], out_values[i], dim - sparse_dim, unused);
values[i].copy_(r);
} else {
-auto r = at::cpu::_softmax_backward_data(grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+auto r = at::cpu::_softmax_backward_data(grad_values[j], out_values[i], dim - sparse_dim, unused);
values[i].copy_(r);
}
}
@@ -561,7 +562,7 @@ Tensor softmax_backward_sparse_cpu(
}
AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
cpu_sparse_coo_softmax_backward<scalar_t, false>(
-grad_input, grad, output, dim_, input_.scalar_type());
+grad_input, grad, output, dim_);
});
return grad_input;
}
@@ -580,7 +581,7 @@ Tensor log_softmax_backward_sparse_cpu(
}
AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
cpu_sparse_coo_softmax_backward<scalar_t, true>(
-grad_input, grad, output, dim_, input_.scalar_type());
+grad_input, grad, output, dim_);
});
return grad_input;
}
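The sparse path above forwards to the dense backward kernels, passing a placeholder `unused` tensor for the restored tensor argument. A small autograd sketch that exercises the sparse softmax backward kernels above (not part of this diff; relies on torch.sparse.softmax and differentiable sparse construction):

import torch

i = torch.tensor([[0, 0, 1], [0, 1, 1]])
v = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
s = torch.sparse_coo_tensor(i, v, (2, 2))

out = torch.sparse.softmax(s, dim=1)
out.to_dense().sum().backward()  # backward routes through the sparse softmax backward
print(v.grad)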
17 changes: 9 additions & 8 deletions aten/src/ATen/native/sparse/cuda/SoftMax.cu
@@ -431,8 +431,7 @@ void cuda_sparse_coo_softmax_backward(
Tensor& grad_input,
const Tensor& grad,
const Tensor& output,
-const int64_t dim,
-ScalarType input_dtype) {
+const int64_t dim) {
/*
See ATen/native/sparse/Softmax.cpp:cpu_sparse_coo_softmax_backward for
the CPU implementation of the sparse softmax backward algorithm that this
@@ -464,12 +463,13 @@ void cuda_sparse_coo_softmax_backward(
/* when dim >= sparse_dim the dense backward is used */
if (dim >= sparse_dim) {
if (at::native::cuda_equal(out_offsets, grad_offsets) == true) {
+Tensor unused = at::native::empty_like(grad_values);
if (LogSoftMax) {
auto r = at::cuda::_log_softmax_backward_data(
-grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+grad_values, out_values, dim - sparse_dim + 1, unused);
values.set_(r);
} else {
-auto r = at::cuda::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, input_dtype);
+auto r = at::cuda::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, unused);
values.set_(r);
}
} else {
@@ -480,6 +480,7 @@
auto out_offsets_accessor = host_out_offsets.data_ptr<int64_t>();
auto grad_offsets_accessor = host_grad_offsets.data_ptr<int64_t>();
for (int64_t i = 0; i < out_nnz; i++) {
+Tensor unused = at::native::empty_like(grad_values);
auto low = thrust::lower_bound(
grad_offsets_accessor,
grad_offsets_accessor + grad_offsets.size(0),
@@ -492,11 +493,11 @@
if (j < grad_nnz && out_offsets_accessor[i] == grad_offsets_accessor[j]) {
if (LogSoftMax) {
auto r = at::cuda::_log_softmax_backward_data(
-grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+grad_values[j], out_values[i], dim - sparse_dim, unused);
values[i].copy_(r);
} else {
auto r = at::cuda::_softmax_backward_data(
-grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
+grad_values[j], out_values[i], dim - sparse_dim, unused);
values[i].copy_(r);
}
}
@@ -608,7 +609,7 @@ Tensor softmax_backward_sparse_cuda(
}
AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
cuda_sparse_coo_softmax_backward<scalar_t, false>(
-grad_input, grad, output, dim_, input_.scalar_type());
+grad_input, grad, output, dim_);
});
return grad_input;
}
@@ -628,7 +629,7 @@ Tensor log_softmax_backward_sparse_cuda(

AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
cuda_sparse_coo_softmax_backward<scalar_t, true>(
-grad_input, grad, output, dim_, input_.scalar_type());
+grad_input, grad, output, dim_);
});
return grad_input;
}
2 changes: 0 additions & 2 deletions test/backward_compatibility/check_backward_compatibility.py
@@ -52,8 +52,6 @@
("aten::randperm", datetime.date(9999, 1, 1)),
("aten::thnn_conv2d_forward", datetime.date(2021, 9, 30)),
("aten::thnn_conv2d_backward", datetime.date(2021, 9, 30)),
("aten::_log_softmax_backward_data", datetime.date(2021, 10, 21)),
("aten::_softmax_backward_data", datetime.date(2021, 10, 21))
]

ALLOW_LIST_COMPILED = [
14 changes: 7 additions & 7 deletions tools/autograd/derivatives.yaml
@@ -1683,7 +1683,7 @@
self: log_sigmoid_backward(grad, self, buffer)

- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
-self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
+self: _log_softmax_backward_data(grad, result, dim, self)

- name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _sparse_log_softmax_backward_data(grad, result, dim, self)
@@ -1701,7 +1701,7 @@
self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)

- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
-self: _softmax_backward_data(grad, result, dim, self.scalar_type())
+self: _softmax_backward_data(grad, result, dim, self)

- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _sparse_softmax_backward_data(grad, result, dim, self)
@@ -1948,9 +1948,9 @@
grad_output: log_sigmoid_backward(grad, self, buffer)
self: log_sigmoid_double_backward(grad * grad_output, self)

-- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
-output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype())
+self: log_softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(self.dtype())

- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
# self_is_result is always false here since double backward call is an out-of-place call, self is input itself
@@ -2030,9 +2030,9 @@
grad_output: softplus_backward(grad, self, beta, threshold, output)
self: softplus_double_backward(grad * grad_output, self, beta, threshold)

-- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
-grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype)
-output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype())
+- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
+grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, self)
+self: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(self.dtype())

- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
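The restored `self:` formulas above define the double backward for these ops; second-order gradcheck is the standard way to exercise them (sketch, not part of this diff):

import torch

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
torch.autograd.gradgradcheck(lambda t: torch.log_softmax(t, dim=1), (x,))
torch.autograd.gradgradcheck(lambda t: torch.softmax(t, dim=1), (x,))
print("softmax/log_softmax double backward matches numerical gradients")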
22 changes: 20 additions & 2 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1234,8 +1234,26 @@ Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
return grad * (z - 1) * z;
}

-Tensor softmax_double_backward(const Tensor& grad, const Tensor& grad_output, int dim, const Tensor& output) {
-return grad_output * grad - (output * grad_output).sum(dim, true) * grad - grad_output * (output * grad).sum(dim, true);
+Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
+const auto& gO = grad_output;
+const auto& ggI = grad;
+
+auto ggI_output = ggI * output;
+auto ggI_out_sum = ggI_output.sum(dim, true);
+auto ggI_out_sum_output = ggI_out_sum * output;
+auto gO_out_sum = (gO * output).sum(dim, true);
+
+// gI calculation
+auto gI_t0 = ggI_output * (gO - gO_out_sum);
+auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum));
+auto gI_t2 = ggI_out_sum_output * gO;
+auto gI_t3 = ggI_out_sum_output * gO_out_sum;
+return gI_t0 - gI_t1 - gI_t2 + gI_t3;
}
+
+Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
+auto z = output.exp();
+return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad);
+}
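A numeric sketch (not part of this diff) of the restored log_softmax_double_backward: contracting the first-order log_softmax backward with an incoming second-order gradient ggI and differentiating with respect to the forward input reproduces the formula above.

import torch

x = torch.randn(2, 5, dtype=torch.double, requires_grad=True)
out = torch.log_softmax(x, dim=1)
go = torch.randn_like(out)                 # upstream gradient of log_softmax
gI, = torch.autograd.grad(out, x, go, create_graph=True)

ggI = torch.randn_like(x)                  # incoming second-order gradient
z = out.exp()
# z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad)
manual = z * go.sum(1, keepdim=True) * ((ggI * z).sum(1, keepdim=True) - ggI)
auto, = torch.autograd.grad(gI, x, ggI)
print(torch.allclose(manual, auto))        # True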

// NOTE: [How to write vmap-compatible backward formulas]
1 change: 1 addition & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -122,6 +122,7 @@ Tensor binary_cross_entropy_target_backward(
at::Tensor binary_cross_entropy_with_logits_target_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, const c10::optional<at::Tensor>& weight, const c10::optional<at::Tensor>& pos_weight, int64_t reduction);
at::Tensor log_sigmoid_double_backward(const at::Tensor & grad, const at::Tensor & input);
at::Tensor softmax_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, int dim, const at::Tensor & output);
+at::Tensor log_softmax_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, int dim, const at::Tensor & output);
at::Tensor binary_cross_entropy_double_backward(const at::Tensor & grad_output, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional<at::Tensor>& weight, int64_t reduction);
at::Tensor binary_cross_entropy_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional<at::Tensor>& weight, int64_t reduction);
at::Tensor l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction);
4 changes: 2 additions & 2 deletions torch/csrc/jit/runtime/symbolic_script.cpp
@@ -1208,7 +1208,7 @@ const std::vector<std::string> functions = {
def log_softmax(self, dim: int, dtype: Optional[int]):
result = torch.log_softmax(self, dim, dtype)
def backward(grad_output):
-grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self.dtype)
+grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self)
return grad_self, None, None
return result, backward
@@ -1222,7 +1222,7 @@ const std::vector<std::string> functions = {
def softmax(self, dim: int, dtype: Optional[int]):
result = torch.softmax(self, dim, dtype)
def backward(grad_output):
-grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
+grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
return grad_self, None, None
return result, backward
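The TorchScript formulas above are the symbolic gradients the JIT can use when differentiating scripted graphs that contain softmax/log_softmax (whether they are hit depends on the executor). A minimal scripted example (not part of this diff):

import torch

@torch.jit.script
def f(x):
    return torch.log_softmax(x, dim=-1).sum()

x = torch.randn(3, 4, requires_grad=True)
f(x).backward()
print(x.grad.shape)  # torch.Size([3, 4])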
