
Commit 52f8ece

caic99 and iProzd authored
perf: calculate grad on-the-fly for SiLUT (#4678)
The current implementation of SiLUT keeps one extra tensor around to store the 1st-order gradient. This PR reduces the memory footprint by calculating the 1st-order gradient on-the-fly in `silut_double_backward`, at the cost of roughly 0.5% extra computation time. I've tested this PR on OMat with 9 DPA-3 layers and batch size=auto:512.

| Metric                | Before | After  | Improvement |
|-----------------------|--------|--------|-------------|
| Peak Memory           | 25.0G  | 21.7G  | -13%        |
| Speed (per 100 steps) | 30.29s | 30.46s | -0.56%      |

The correctness of this modification is covered by `source/tests/pt/test_custom_activation.py`.

## Summary by CodeRabbit

- **Refactor**
  - Streamlined internal computation logic by refining variable naming for clarity.
  - Updated public method signatures to return outputs in a structured tuple, ensuring more intuitive and consistent integration.

---------

Signed-off-by: Chun Cai <[email protected]>
Co-authored-by: Duo <[email protected]>
1 parent e1b7a9f commit 52f8ece
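The memory saving comes from the usual "recompute instead of save" trade-off for custom autograd functions. The sketch below is a minimal, self-contained illustration of that pattern, not the deepmd code; `CubeFunction` is a made-up example. The backward pass saves only the input and rebuilds the 1st-order gradient when it is needed, which is the same trade the PR applies one level deeper in `silut_double_backward`.

```python
import torch


class CubeFunction(torch.autograd.Function):
    """Illustrative only: y = x**3 with the gradient recomputed on the fly."""

    @staticmethod
    def forward(ctx, x):
        # Save only the input; the 1st-order gradient 3*x**2 is NOT stored.
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad = 3 * x * x  # recomputed here instead of being kept in memory
        return grad * grad_output


x = torch.randn(4, dtype=torch.float64, requires_grad=True)
(g,) = torch.autograd.grad(CubeFunction.apply(x).sum(), x)
assert torch.allclose(g, 3 * x.detach() ** 2)
```

In the diff below, the same idea removes `grad` from `ctx.save_for_backward` in `SiLUTGradFunction.forward` and recomputes it inside `silut_double_backward`.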

File tree

1 file changed: +24 -22 lines changed


deepmd/pt/utils/utils.py

Lines changed: 24 additions & 22 deletions
```diff
@@ -26,21 +26,21 @@ def silut_forward(
 ) -> torch.Tensor:
     sig = torch.sigmoid(x)
     silu = x * sig
-    tanh_part = torch.tanh(slope * (x - threshold)) + const_val
-    return torch.where(x >= threshold, tanh_part, silu)
+    tanh = torch.tanh(slope * (x - threshold)) + const_val
+    return torch.where(x >= threshold, tanh, silu)
 
 
 def silut_backward(
     x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float
-):
+) -> torch.Tensor:
     sig = torch.sigmoid(x)
     grad_silu = sig * (1 + x * (1 - sig))
 
-    tanh_term = torch.tanh(slope * (x - threshold))
-    grad_tanh = slope * (1 - tanh_term.pow(2))
+    tanh = torch.tanh(slope * (x - threshold))
+    grad_tanh = slope * (1 - tanh * tanh)
 
     grad = torch.where(x >= threshold, grad_tanh, grad_silu)
-    return grad * grad_output, grad
+    return grad * grad_output
 
 
 def silut_double_backward(
@@ -49,19 +49,23 @@ def silut_double_backward(
     grad_output: torch.Tensor,
     threshold: float,
     slope: float,
-) -> torch.Tensor:
-    # Tanh branch
-    tanh_term = torch.tanh(slope * (x - threshold))
-    grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term)
-
+) -> tuple[torch.Tensor, torch.Tensor]:
     # SiLU branch
-    sig = 1.0 / (1.0 + torch.exp(-x))
+    sig = torch.sigmoid(x)
+
     sig_prime = sig * (1 - sig)
-    silu_term = sig_prime * (2 + x * (1 - 2 * sig))
+    grad_silu = sig + x * sig_prime
+    grad_grad_silu = sig_prime * (2 + x * (1 - 2 * sig))
 
-    grad_grad = torch.where(x >= threshold, grad_grad, silu_term)
+    # Tanh branch
+    tanh = torch.tanh(slope * (x - threshold))
+    tanh_square = tanh * tanh  # .square is slow for jit.script!
+    grad_tanh = slope * (1 - tanh_square)
+    grad_grad_tanh = -2 * slope * tanh * grad_tanh
 
-    return grad_output * grad_grad * grad_grad_output
+    grad = torch.where(x >= threshold, grad_tanh, grad_silu)
+    grad_grad = torch.where(x >= threshold, grad_grad_tanh, grad_grad_silu)
+    return grad_output * grad_grad * grad_grad_output, grad * grad_grad_output
 
 
 class SiLUTScript(torch.nn.Module):
@@ -105,22 +109,20 @@ class SiLUTGradFunction(torch.autograd.Function):
             def forward(ctx, x, grad_output, threshold, slope):
                 ctx.threshold = threshold
                 ctx.slope = slope
-                grad_input, grad = silut_backward_script(
-                    x, grad_output, threshold, slope
-                )
-                ctx.save_for_backward(x, grad_output, grad)
+                grad_input = silut_backward_script(x, grad_output, threshold, slope)
+                ctx.save_for_backward(x, grad_output)
                 return grad_input
 
             @staticmethod
             def backward(ctx, grad_grad_output):
-                (x, grad_output, grad) = ctx.saved_tensors
+                (x, grad_output) = ctx.saved_tensors
                 threshold = ctx.threshold
                 slope = ctx.slope
 
-                grad_input = silut_double_backward_script(
+                grad_input, grad_mul_grad_grad_output = silut_double_backward_script(
                     x, grad_grad_output, grad_output, threshold, slope
                 )
-                return grad_input, grad * grad_grad_output, None, None
+                return grad_input, grad_mul_grad_grad_output, None, None
 
         self.SiLUTFunction = SiLUTFunction
 
```
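As a quick, standalone sanity check of the closed-form derivatives used in `silut_backward` and `silut_double_backward` (a sketch for illustration only, not part of the PR or of `source/tests/pt/test_custom_activation.py`; `slope` and `threshold` below are arbitrary example values), both branches can be compared against `torch.autograd` directly:

```python
import torch

# SiLU branch f(x) = x * sigmoid(x):
#   f'(x)  = sig * (1 + x * (1 - sig))
#   f''(x) = sig * (1 - sig) * (2 + x * (1 - 2 * sig))
# Tanh branch g(x) = tanh(slope * (x - threshold)) + const:
#   g'(x)  = slope * (1 - tanh**2)
#   g''(x) = -2 * slope * tanh * g'(x)
torch.manual_seed(0)
x = torch.randn(64, dtype=torch.float64, requires_grad=True)
slope, threshold = 0.3, 3.0  # example values, not the deepmd defaults

# Autograd reference for the SiLU branch
f = x * torch.sigmoid(x)
(df,) = torch.autograd.grad(f.sum(), x, create_graph=True)
(d2f,) = torch.autograd.grad(df.sum(), x)
sig = torch.sigmoid(x)
assert torch.allclose(df, sig * (1 + x * (1 - sig)))
assert torch.allclose(d2f, sig * (1 - sig) * (2 + x * (1 - 2 * sig)))

# Autograd reference for the tanh branch
g = torch.tanh(slope * (x - threshold))
(dg,) = torch.autograd.grad(g.sum(), x, create_graph=True)
(d2g,) = torch.autograd.grad(dg.sum(), x)
tanh = torch.tanh(slope * (x - threshold))
grad_tanh = slope * (1 - tanh * tanh)
assert torch.allclose(dg, grad_tanh)
assert torch.allclose(d2g, -2 * slope * tanh * grad_tanh)
```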
