9
9
10
10
11
11
class Lambda (CoModule , nn .Module ):
12
- """Module wrapper for stateless functions
12
+ """Module wrapper for stateless functions.
13
13
14
14
NB: Operations performed in a Lambda are not counted in `ptflops`
15
+
16
+ Args:
17
+ fn (Callable[[Tensor], Tensor]): Function to be called during forward.
18
+ takes_time (bool, optional): If True, `fn` recieves all steps, if False, it received one step and no time dimension. Defaults to False.
15
19
"""
16
20
17
- def __init__ (self , fn : Callable [[Tensor ], Tensor ], unsqueeze_step = True ):
21
+ def __init__ (self , fn : Callable [[Tensor ], Tensor ], takes_time = False ):
18
22
nn .Module .__init__ (self )
19
23
assert callable (fn ), "The pased function should be callable."
20
24
self .fn = fn
21
- self .unsqueeze_step = unsqueeze_step
25
+ self .takes_time = takes_time
22
26
23
27
def __repr__ (self ) -> str :
24
28
s = self .fn .__name__
@@ -47,26 +51,27 @@ def __repr__(self) -> str:
47
51
return f"Lambda({ s } )"
48
52
49
53
def forward (self , input : Tensor ) -> Tensor :
50
- return self .fn (input )
54
+ if self .takes_time :
55
+ return self .fn (input )
56
+
57
+ return torch .stack (
58
+ [self .fn (input [:, :, t ]) for t in range (input .shape [2 ])], dim = 2
59
+ )
60
+
61
+ def forward_steps (self , input : Tensor , pad_end = False , update_state = True ) -> Tensor :
62
+ return self .forward (input )
51
63
52
64
def forward_step (self , input : Tensor , update_state = True ) -> Tensor :
53
- if self .unsqueeze_step :
65
+ if self .takes_time :
54
66
input = input .unsqueeze (dim = 2 )
55
67
output = self .fn (input )
56
- if self .unsqueeze_step :
68
+ if self .takes_time :
57
69
output = output .squeeze (dim = 2 )
58
70
return output
59
71
60
- def forward_steps (self , input : Tensor , pad_end = False , update_state = True ) -> Tensor :
61
- return self .fn (input )
62
-
63
- @property
64
- def delay (self ) -> int :
65
- return 0
66
-
67
72
@staticmethod
68
- def build_from (fn : Callable [[Tensor ], Tensor ]) -> "Lambda" :
69
- return Lambda (fn )
73
+ def build_from (fn : Callable [[Tensor ], Tensor ], takes_time = False ) -> "Lambda" :
74
+ return Lambda (fn , takes_time )
70
75
71
76
72
77
def _multiply (x : Tensor , factor : Union [float , int , Tensor ]):
@@ -76,7 +81,7 @@ def _multiply(x: Tensor, factor: Union[float, int, Tensor]):
76
81
def Multiply (factor ) -> Lambda :
77
82
"""Create Lambda with multiplication function"""
78
83
fn = partial (_multiply , factor = factor )
79
- return Lambda (fn )
84
+ return Lambda (fn , takes_time = True )
80
85
81
86
82
87
def _add (x : Tensor , constant : Union [float , int , Tensor ]):
@@ -86,7 +91,7 @@ def _add(x: Tensor, constant: Union[float, int, Tensor]):
86
91
def Add (constant ) -> Lambda :
87
92
"""Create Lambda with addition function"""
88
93
fn = partial (_add , constant = constant )
89
- return Lambda (fn )
94
+ return Lambda (fn , takes_time = True )
90
95
91
96
92
97
def _unity (x : Tensor ):
@@ -95,18 +100,18 @@ def _unity(x: Tensor):
95
100
96
101
def Unity () -> Lambda :
97
102
"""Create Lambda with addition function"""
98
- return Lambda (_unity )
103
+ return Lambda (_unity , takes_time = True )
99
104
100
105
101
106
def Constant (constant : float ):
102
- return Lambda (lambda x : constant * torch .ones_like (x ))
107
+ return Lambda (lambda x : constant * torch .ones_like (x ), takes_time = True )
103
108
104
109
105
110
def Zero () -> Lambda :
106
111
"""Create Lambda with zero output"""
107
- return Lambda (torch .zeros_like )
112
+ return Lambda (torch .zeros_like , takes_time = True )
108
113
109
114
110
115
def One () -> Lambda :
111
116
"""Create Lambda with zero output"""
112
- return Lambda (torch .ones_like )
117
+ return Lambda (torch .ones_like , takes_time = True )
0 commit comments