
Commit 0670951

Add aten_hardtanh_backward function (#1715)
Depends on #1707; a unit test will be added after #1707 is merged.
1 parent 619f5ed commit 0670951

File tree (1 file changed: +4 -1 lines)
  • onnxscript/function_libs/torch_lib/ops


onnxscript/function_libs/torch_lib/ops/nn.py (+4 -1)
@@ -632,12 +632,15 @@ def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> T
     return op.Clip(self, min_val, max_val)
 
 
+@torch_op("aten::hardtanh_backward", trace_only=True)
 def aten_hardtanh_backward(
     grad_output: TensorType, self: TensorType, min_val: float, max_val: float
 ) -> TensorType:
     """hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor"""
 
-    raise NotImplementedError()
+    max_mask = op.Where(op.Greater(self, max_val), 0.0, 1.0)
+    min_mask = op.Where(op.Less(self, min_val), 0.0, 1.0)
+    return op.Mul(op.Mul(grad_output, max_mask), min_mask)
 
 
 def aten_huber_loss(
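
For reference, the added function passes the incoming gradient through wherever min_val <= self <= max_val and zeroes it elsewhere, built from ONNX Where/Greater/Less/Mul ops. Below is a minimal NumPy sketch of that same masking logic; it is not part of the commit, and the helper name hardtanh_backward_ref is hypothetical.

import numpy as np

def hardtanh_backward_ref(grad_output, self_, min_val=-1.0, max_val=1.0):
    # Mirrors the op.Where(op.Greater(...)) / op.Where(op.Less(...)) masks above:
    # the gradient is zeroed where self_ > max_val or self_ < min_val.
    max_mask = np.where(self_ > max_val, 0.0, 1.0)
    min_mask = np.where(self_ < min_val, 0.0, 1.0)
    return grad_output * max_mask * min_mask

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
g = np.ones_like(x)
print(hardtanh_backward_ref(g, x))  # [0. 1. 1. 1. 0.]
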

0 commit comments
