forked from borisfom/cudnn.torch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvert.lua
More file actions
62 lines (58 loc) · 1.65 KB
/
convert.lua
File metadata and controls
62 lines (58 loc) · 1.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
-- Layer class names, written without the 'nn.' / 'cudnn.' prefix, that
-- cudnn.convert treats as directly interchangeable between the two backends.
local layer_list = {
  -- 2D (spatial) layers
  'SpatialConvolution', 'SpatialCrossMapLRN',
  'SpatialMaxPooling', 'SpatialAveragePooling',
  -- pointwise activations and softmax variants
  'ReLU', 'Tanh', 'Sigmoid',
  'SoftMax', 'LogSoftMax',
  -- 3D (volumetric) layers
  'VolumetricConvolution',
  'VolumetricMaxPooling', 'VolumetricAveragePooling',
}
-- Recursive substitution helper, in the spirit of nn.Module.apply.
--
-- `callback` is invoked on `self` first; whatever it returns becomes the
-- result for this node. Afterwards, when `self` has a `modules` list, each
-- entry of that list is overwritten in place with the result of running
-- `replace` on it (so the callback sees parents before their children).
--
--   self     - node to visit; only its optional `modules` field is inspected
--   callback - function(node) returning the value that should stand in for
--              `node`
-- Returns the value `callback` produced for `self`.
local function replace(self, callback)
  local replacement = callback(self)
  if not self.modules then
    -- leaf node: nothing to descend into
    return replacement
  end
  for index, child in ipairs(self.modules) do
    self.modules[index] = replace(child, callback)
  end
  return replacement
end
-- Goes over a given net and converts all convertible layers to the dst
-- backend, for example: net = cudnn.convert(net, cudnn)
--
--   net - root module; it and everything reachable through `modules` lists
--         is visited via `replace`
--   dst - target backend table. `nn` selects the cudnn -> nn direction; any
--         other value is treated as the nn -> cudnn direction.
--
-- Returns the module that stands in for `net` (a new object if the root
-- itself was converted, otherwise `net`). Children are replaced in place
-- inside their parent's `modules` list. Modules whose type is not handled
-- are returned unchanged.
function cudnn.convert(net, dst)
  -- Direction and type-name prefixes are the same for every visited module,
  -- so work them out once instead of per callback invocation.
  local to_nn = dst == nn
  local src = to_nn and cudnn or nn
  local src_prefix = to_nn and 'cudnn.' or 'nn.'
  local dst_prefix = to_nn and 'nn.' or 'cudnn.'

  return replace(net, function(x)
    -- Builds the dst-backend counterpart of x for the bare layer name `name`
    -- by re-tagging a fresh table with the destination class and copying
    -- every field of x into it (fields are shared by reference, not cloned).
    local function convert(name)
      local converted
      if name == 'ReLU' then
        -- Original note: "because parameters". ReLU is built through its
        -- constructor instead of a bare re-tagged table, presumably so that
        -- constructor-initialised fields exist (TODO confirm).
        converted = dst.ReLU()
      else
        converted = {}
        torch.setmetatable(converted, dst_prefix .. name)
      end
      for k, u in pairs(x) do converted[k] = u end
      if src == cudnn and x.clearDesc then
        -- The field copy above also copied x's descriptor fields, so they
        -- have to be dropped from the converted module as well; clearing
        -- only x would leave the copy holding them.
        x:clearDesc()
        x.clearDesc(converted)
      end
      return converted
    end

    local t = torch.typename(x)
    -- Types whose name differs from the layer they map to.
    if t == 'nn.SpatialConvolutionMM' then
      return convert('SpatialConvolution')
    elseif t == 'inn.SpatialCrossResponseNormalization' then
      return convert('SpatialCrossMapLRN')
    end
    -- Same-named layers: a type name matches at most one entry, so stop at
    -- the first hit.
    for _, name in ipairs(layer_list) do
      if t == src_prefix .. name then
        return convert(name)
      end
    end
    return x
  end)
end