from .env import PRECISION_DICT as PT_PRECISION_DICT


+class CustomSilu(torch.nn.Module):
+    def __init__(self, threshold=3.0):
+        super().__init__()
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        # SiLU's value and slope at the threshold, so the tanh tail below
+        # matches SiLU in value and first derivative at x == threshold.
+        self.threshold = threshold
+        self.slope = float(silu_grad(threshold))
+        self.const = float(silu(threshold))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        silu_part = F.silu(x)
+        mask = x > self.threshold
+        if torch.any(mask):
+            # Beyond the threshold, replace SiLU with a bounded tanh tail.
+            tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
+            return torch.where(x < self.threshold, silu_part, tanh_part)
+        else:
+            return silu_part
+
+
class ActivationFn(torch.nn.Module):
    def __init__(self, activation: Optional[str]) -> None:
        super().__init__()
        self.activation: str = activation if activation is not None else "linear"
+        if self.activation.startswith("custom_silu"):
+            # An optional threshold can be appended to the name,
+            # e.g. "custom_silu:2.0"; otherwise default to 3.0.
+            threshold = (
+                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
+            )
+            self.custom_silu = CustomSilu(threshold=threshold)
+        else:
+            self.custom_silu = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -41,6 +76,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.sigmoid(x)
        elif self.activation.lower() == "silu":
            return F.silu(x)
+        elif self.activation.startswith("custom_silu"):
+            assert self.custom_silu is not None
+            return self.custom_silu(x)
        elif self.activation.lower() == "linear" or self.activation.lower() == "none":
            return x
        else:
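
Below is a minimal sketch (not part of the patch) of how the new activation string could be exercised. The import path is an assumption based on the relative `.env` import above; adjust it to wherever `ActivationFn` actually lives. The check confirms that `custom_silu` agrees with standard SiLU below the parsed threshold and stays bounded above it.

import torch
import torch.nn.functional as F

# Assumed import path, not confirmed by the diff.
from deepmd.pt.utils.utils import ActivationFn

act = ActivationFn("custom_silu:2.0")  # threshold parsed from the name
x = torch.linspace(-6.0, 6.0, steps=25)
y = act(x)

# Below the threshold the output is plain SiLU.
below = x < 2.0
assert torch.allclose(y[below], F.silu(x[below]))

# Above it, the tanh tail keeps the output bounded by silu(threshold) + 1.
assert torch.all(y[~below] <= F.silu(torch.tensor(2.0)) + 1.0)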