
Commit cc7a43f

committed: add CustomSilu
1 parent 0d17776 · commit cc7a43f

1 file changed, +38 −0 lines changed

deepmd/pt/utils/utils.py

@@ -18,10 +18,45 @@
 from .env import PRECISION_DICT as PT_PRECISION_DICT
 
 
+class CustomSilu(torch.nn.Module):
+    def __init__(self, threshold=3.0):
+        super().__init__()
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        self.threshold = threshold
+        self.slope = float(silu_grad(threshold))
+        self.const = float(silu(threshold))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        silu_part = F.silu(x)
+        mask = x > self.threshold
+        if torch.any(mask):
+            tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
+            return torch.where(x < self.threshold, silu_part, tanh_part)
+        else:
+            return silu_part
+
+
 class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]) -> None:
         super().__init__()
         self.activation: str = activation if activation is not None else "linear"
+        if self.activation.startswith("custom_silu"):
+            threshold = (
+                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
+            )
+            self.custom_silu = CustomSilu(threshold=threshold)
+        else:
+            self.custom_silu = None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -41,6 +76,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             return torch.sigmoid(x)
         elif self.activation.lower() == "silu":
             return F.silu(x)
+        elif self.activation.startswith("custom_silu"):
+            assert self.custom_silu is not None
+            return self.custom_silu(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
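
On the dispatch side, the new activation is requested by name, and __init__ accepts an optional ":<threshold>" suffix (e.g. "custom_silu:2.0"); without the suffix the threshold falls back to 3.0. A minimal usage sketch under the same import assumption as above:

import torch

from deepmd.pt.utils.utils import ActivationFn

default_act = ActivationFn("custom_silu")    # threshold falls back to 3.0
tight_act = ActivationFn("custom_silu:2.0")  # threshold parsed from the name

print(default_act.custom_silu.threshold)  # 3.0
print(tight_act.custom_silu.threshold)    # 2.0

x = torch.linspace(-1.0, 5.0, steps=7)
print(tight_act(x))  # caps sooner than the default variant
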
