shape_flop_util.py
from functools import reduce, singledispatch
from collections import namedtuple
from contextlib import contextmanager
import math
import operator
import sys

import torch
import torch.nn as nn
import torch.nn.functional as fn

from nnsearch.pytorch.gated.module import (
    BlockGatedConv3d, BlockGatedConv2d, BlockGatedFullyConnected)

Flops = namedtuple("Flops", ["macc"])


@singledispatch
def output_shape(layer, input_shape):
    """ Computes the output shape given a layer and input shape, without
    evaluating the layer. Raises `NotImplementedError` for unsupported layer
    types.

    Parameters:
        `layer` : The layer whose output shape is desired
        `input_shape` : The shape of the input, in the format (C, H, W, ...).
            Note that this must *not* include a "batch" dimension or anything
            similar.
    """
    raise NotImplementedError(layer)
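# Usage sketch (not part of the original module): for an nn.Conv2d layer with
# kernel 3, stride 1 and no padding, a (3, 32, 32) input gives a (16, 30, 30)
# output shape, computed without ever running the layer.
#
#     conv = nn.Conv2d(3, 16, kernel_size=3)
#     assert output_shape(conv, (3, 32, 32)) == (16, 30, 30)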
@output_shape.register(nn.Conv2d)
def _(layer, input_shape):
    return _output_shape_Conv(2, input_shape, layer.out_channels,
        layer.kernel_size, layer.stride, layer.padding, layer.dilation, False)


@output_shape.register(nn.Conv3d)
def _(layer, input_shape):
    return _output_shape_Conv(3, input_shape, layer.out_channels,
        layer.kernel_size, layer.stride, layer.padding, layer.dilation, False)


@output_shape.register(nn.Linear)
def _(layer, input_shape):
    assert flat_size(input_shape) == layer.in_features
    return tuple([layer.out_features])


@output_shape.register(nn.AvgPool2d)
def _(layer, input_shape):
    out_channels = input_shape[0]
    return _output_shape_Conv(2, input_shape, out_channels, layer.kernel_size,
        layer.stride, layer.padding, 1, layer.ceil_mode)


@output_shape.register(nn.Softmax)
@output_shape.register(nn.BatchNorm2d)
@output_shape.register(nn.BatchNorm3d)
@output_shape.register(nn.ReLU)
@output_shape.register(nn.Dropout)
def _(layer, input_shape):
    return input_shape


@output_shape.register(nn.MaxPool2d)
def _(layer, input_shape):
    out_channels = input_shape[0]
    return _output_shape_Conv(2, input_shape, out_channels, layer.kernel_size,
        layer.stride, layer.padding, layer.dilation, layer.ceil_mode)


@output_shape.register(nn.MaxPool3d)
def _(layer, input_shape):
    out_channels = input_shape[0]
    return _output_shape_Conv(3, input_shape, out_channels, layer.kernel_size,
        layer.stride, layer.padding, layer.dilation, layer.ceil_mode)


@output_shape.register(nn.Sequential)
def _(layer, input_shape):
    for m in layer.children():
        input_shape = output_shape(m, input_shape)
    return input_shape


@output_shape.register(BlockGatedFullyConnected)
def _(layer, input_shape):
    # Total output size is the per-block width times the number of blocks.
    out_channels = layer.components[0].out_features * len(layer.components)
    return tuple([out_channels])


@output_shape.register(BlockGatedConv3d)
def _(layer, input_shape):
    # Output channels are the per-block channels times the number of blocks;
    # spatial dimensions follow the usual convolution arithmetic.
    out_channels = layer.components[0].out_channels * len(layer.components)
    return _output_shape_Conv(3, input_shape, out_channels,
        layer.components[0].kernel_size, layer.components[0].stride,
        layer.components[0].padding, layer.components[0].dilation, False)
@singledispatch
def flops(layer, in_shape):
    """ Computes the number of multiply-accumulate operations (MACCs) performed
    by a layer on an input of the given shape. Raises `NotImplementedError` for
    unsupported layer types.
    """
    raise NotImplementedError(layer)


@flops.register(nn.Sequential)
def _(layer, in_shape):
    result = Flops(0)
    for m in layer:
        mf = flops(m, in_shape)
        result = Flops(*(sum(x) for x in zip(mf, result)))
        in_shape = output_shape(m, in_shape)
    return result


@flops.register(nn.MaxPool3d)
def _(layer, in_shape):
    return Flops(0)


@flops.register(nn.BatchNorm3d)
def _(layer, in_shape):
    return Flops(0)


@flops.register(nn.Dropout)
def _(layer, in_shape):
    return Flops(0)


@flops.register(nn.Conv2d)
def _(layer, in_shape):
    out_shape = output_shape(layer, in_shape)
    k = reduce(operator.mul, layer.kernel_size)
    out_dim = reduce(operator.mul, out_shape)
    macc = k * in_shape[0] * out_dim / layer.groups
    return Flops(macc)


@flops.register(nn.Conv3d)
def _(layer, in_shape):
    out_shape = output_shape(layer, in_shape)
    k = reduce(operator.mul, layer.kernel_size)
    out_dim = reduce(operator.mul, out_shape)
    macc = k * in_shape[0] * out_dim / layer.groups
    return Flops(macc)


@flops.register(nn.Linear)
def _(layer, in_shape):
    # assert( flat_size(in_shape) == layer.in_features )
    macc = layer.in_features * layer.out_features
    return Flops(macc)


@flops.register(nn.ReLU)
def _(layer, in_shape):
    return Flops(0)
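# Worked example (a sketch, not part of the original module): for
# nn.Conv2d(3, 16, kernel_size=3) on a (3, 32, 32) input, the output shape is
# (16, 30, 30), so
#     macc = 3*3 * 3 * (16*30*30) / 1 = 388,800
# multiply-accumulates, i.e. roughly 0.78 MFLOPs if one MACC is counted as two
# floating-point operations.
#
#     conv = nn.Conv2d(3, 16, kernel_size=3)
#     assert flops(conv, (3, 32, 32)).macc == 3*3 * 3 * 16*30*30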
def cat(xs, dim=0):
    # NOTE: `TensorTuple` is not defined in this file; it is assumed to be
    # provided elsewhere in the project (a tuple-like wrapper around tensors).
    if len(xs) > 0:
        if isinstance(xs[0], tuple):
            xs = list(zip(*xs))
            return TensorTuple([torch.cat(x, dim=dim) for x in xs])
    return torch.cat(xs, dim=dim)


def flat_size(shape):
    """ Total number of elements in a tensor of the given shape. """
    return reduce(operator.mul, shape)


def _maybe_expand_tuple(dim, tuple_or_int):
    """ Expands a scalar argument (e.g. `kernel_size=3`) into a `dim`-tuple. """
    if type(tuple_or_int) is int:
        tuple_or_int = tuple([tuple_or_int] * dim)
    else:
        assert type(tuple_or_int) is tuple
    return tuple_or_int
def _output_shape_Conv(dim, input_shape, out_channels, kernel_size, stride,
                       padding, dilation, ceil_mode):
    """ Implements `output_shape` for "conv-like" layers, including pooling
    layers.
    """
    assert len(input_shape) == dim + 1
    kernel_size = _maybe_expand_tuple(dim, kernel_size)
    stride = _maybe_expand_tuple(dim, stride)
    padding = _maybe_expand_tuple(dim, padding)
    dilation = _maybe_expand_tuple(dim, dilation)
    quantize = math.ceil if ceil_mode else math.floor
    out_dim = [quantize(
        (input_shape[i + 1] + 2 * padding[i]
         - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i] + 1)
        for i in range(dim)]
    return tuple([out_channels] + out_dim)
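# Ceil-mode example (a sketch, not part of the original module): for
# nn.MaxPool2d(kernel_size=2, stride=2) on a (3, 7, 7) input, the spatial term
# is (7 - 1*(2-1) - 1)/2 + 1 = 3.5, so floor mode gives a (3, 3, 3) output and
# ceil_mode=True gives (3, 4, 4).
#
#     assert output_shape(nn.MaxPool2d(2, stride=2), (3, 7, 7)) == (3, 3, 3)
#     assert output_shape(nn.MaxPool2d(2, stride=2, ceil_mode=True),
#                         (3, 7, 7)) == (3, 4, 4)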
class GlobalAvgPool2d(nn.Module):
    """ Averages over the spatial dimensions, producing an (N, C) output. """

    def forward(self, x):
        kernel_size = x.size()[2:]
        y = fn.avg_pool2d(x, kernel_size)
        while len(y.size()) > 2:
            y = y.squeeze(-1)
        return y


@output_shape.register(GlobalAvgPool2d)
def _(layer, input_shape):
    out_channels = input_shape[0]
    return (out_channels, 1, 1)


@flops.register(GlobalAvgPool2d)
def _(layer, input_shape):
    channels = input_shape[0]
    n = input_shape[1] * input_shape[2]
    # Call a division 4 flops, minus one for the sum
    total = channels * (n * n * 4 + (n * n - 1))
    # Divide by 2 because Flops actually represents MACCs
    return Flops(total / 2)
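# Usage sketch (not part of the original module): on a batch of 8-channel 5x5
# feature maps, GlobalAvgPool2d collapses the spatial dimensions entirely.
#
#     pool = GlobalAvgPool2d()
#     y = pool(torch.randn(1, 8, 5, 5))
#     assert tuple(y.size()) == (1, 8)
#     assert output_shape(pool, (8, 5, 5)) == (8, 1, 1)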
@contextmanager
def printoptions_nowrap(restore={"profile": "default"}, **kwargs):
    """ Convenience wrapper around `torch.set_printoptions` that sets
    'linewidth' to its maximum value, then restores the options given in
    `restore` on exit.
    """
    if "linewidth" in kwargs:
        raise ValueError("user-specified 'linewidth' overrides 'nowrap'")
    torch.set_printoptions(linewidth=sys.maxsize, **kwargs)
    yield
    torch.set_printoptions(**restore)


def unsqueeze_right_as(x, reference):
    """ Appends singleton dimensions to `x` until it has as many dimensions as
    `reference`.
    """
    while len(x.size()) < len(reference.size()):
        x = x.unsqueeze(-1)
    return x


def unsqueeze_left_as(x, reference):
    """ Prepends singleton dimensions to `x` until it has as many dimensions as
    `reference`.
    """
    while len(x.size()) < len(reference.size()):
        x = x.unsqueeze(0)
    return x


def unsqueeze_right_to_dim(x, dim):
    while len(x.size()) < dim:
        x = x.unsqueeze(-1)
    return x


def unsqueeze_left_to_dim(x, dim):
    while len(x.size()) < dim:
        x = x.unsqueeze(0)
    return x
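# Usage sketch (not part of the original module): reshaping a per-example
# scale factor so it broadcasts against a (N, C, H) tensor.
#
#     scale = torch.randn(4)
#     feats = torch.randn(4, 3, 2)
#     assert tuple(unsqueeze_right_as(scale, feats).size()) == (4, 1, 1)
#     assert tuple(unsqueeze_left_to_dim(scale, 3).size()) == (1, 1, 4)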