forked from borisfom/cudnn.torch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBatchNormalization.lua
More file actions
133 lines (116 loc) · 4.62 KB
/
BatchNormalization.lua
File metadata and controls
133 lines (116 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
-- cudnn-backed batch normalization, registered as a subclass of nn.Module.
local BatchNormalization, parent =
   torch.class('cudnn.BatchNormalization', 'nn.Module')
local ffi = require('ffi')
local errcheck = cudnn.errcheck

-- Class-level defaults, read by instances through self.nDim / self.mode:
-- the expected input rank and the cudnn batch-norm mode name.
BatchNormalization.nDim = 2
BatchNormalization.mode = 'CUDNN_BATCHNORM_PER_ACTIVATION'
--- Construct a cudnn batch-normalization layer.
-- @param nFeature number of feature planes; sizes weight, bias and the
--                 running statistics
-- @param eps      epsilon passed to cudnn (default 1e-5)
-- @param momentum passed to cudnn as the exponential-average factor for the
--                 running statistics (default 0.1)
-- @param affine   must be true or nil; only the affine variant is supported
function BatchNormalization:__init(nFeature, eps, momentum, affine)
   parent.__init(self)
   assert(type(nFeature) == 'number',
          'Missing argument #1: Number of feature planes. ')
   -- The old message here advised calling with affine=false, which the very
   -- next assertion rejects; report the real limitation instead.
   assert(nFeature ~= 0, 'Number of feature planes must be non-zero; '
          .. 'cudnn.BatchNormalization does not support affine=false')
   assert(affine == true or affine == nil, 'only affine supported')
   self.affine = true
   self.eps = eps or 1e-5
   self.train = true
   self.momentum = momentum or 0.1
   self.running_mean = torch.zeros(nFeature)
   -- Handed to cudnn as the running variance buffer (variance vs. inverse
   -- variance depends on the linked cudnn version -- TODO confirm).
   self.running_std = torch.ones(nFeature)
   if self.affine then
      self.weight = torch.Tensor(nFeature)
      self.bias = torch.Tensor(nFeature)
      self.gradWeight = torch.Tensor(nFeature)
      self.gradBias = torch.Tensor(nFeature)
      self:reset()
   end
end
--- Re-initialise parameters and running statistics.
-- weight is refilled with :uniform(), bias with zeros (each only if present);
-- running_mean is zeroed and running_std is filled with ones.
function BatchNormalization:reset()
   local w, b = self.weight, self.bias
   if w then w:uniform() end
   if b then b:zero() end
   self.running_mean:zero()
   self.running_std:fill(1)
end
-- table.unpack is missing on Lua 5.1 / LuaJIT without 5.2 compat.
local unpackFn = table.unpack or unpack

-- Snapshot t's per-dimension strides into a plain Lua table.
local function stridesOf(t)
   local s = {}
   for i = 1, t:dim() do s[i] = t:stride(i) end
   return s
end

-- True when t's strides match a snapshot taken by stridesOf.
local function sameStrides(t, cached)
   if cached == nil or #cached ~= t:dim() then return false end
   for i = 1, t:dim() do
      if t:stride(i) ~= cached[i] then return false end
   end
   return true
end

--- (Re)build the cudnn descriptors for `input` and size output/gradInput.
-- Descriptors are cached and rebuilt when missing (e.g. after clearDesc) or
-- when the size OR the strides of `input` change: cudnn tensor descriptors
-- carry strides, so a same-sized but differently-strided input must not
-- reuse a stale iDesc.
-- @param input CudaTensor with exactly self.nDim dimensions
function BatchNormalization:createIODescriptors(input)
   assert(input:dim() == self.nDim, string.format(
      'cudnn.BatchNormalization: expected a %dD input, got %dD',
      self.nDim, input:dim()))
   assert(torch.typename(self.weight) == 'torch.CudaTensor' and torch.typename(self.bias) == 'torch.CudaTensor',
          'Only CUDA tensors are supported for cudnn.BatchNormalization!')
   if not self.iDesc or not self.oDesc or not input:isSize(self.iSize)
      or not sameStrides(input, self.iStride) then
      local nFeature = self.running_mean:numel()
      self.iSize = input:size()
      self.iStride = stridesOf(input)
      self.output:resizeAs(input)
      self.gradInput:resizeAs(input)
      self.iDesc = cudnn.toDescriptor(input)
      self.oDesc = cudnn.toDescriptor(self.output)
      -- Scale/bias/statistics descriptor: size 1 in every dimension except
      -- dimension 2, which holds nFeature.
      local biasSize = torch.ones(self.nDim):totable()
      biasSize[2] = nFeature
      self.sDesc = cudnn.toDescriptor(self.bias:view(unpackFn(biasSize)))
   end
end
-- Single-element CPU FloatTensors whose data pointers are passed to cudnn
-- as scaling factors; scaleTens is refilled with the gradient scale in backward.
local one = torch.FloatTensor{1}
local zero = torch.FloatTensor{0}
local scaleTens = torch.FloatTensor(1)
--- Forward pass.
-- Training mode: calls cudnnBatchNormalizationForwardTraining, handing it
-- self.momentum with running_mean / running_std, and save_mean / save_std
-- as the buffers that backward later reads.
-- Evaluation mode: calls cudnnBatchNormalizationForwardInference with
-- running_mean / running_std.
-- @return self.output
function BatchNormalization:updateOutput(input)
   self:createIODescriptors(input)

   if not self.save_mean then self.save_mean = input.new() end
   self.save_mean:resizeAs(self.running_mean)
   if not self.save_std then self.save_std = input.new() end
   self.save_std:resizeAs(self.running_std)

   local alpha, beta = one:data(), zero:data()
   if not self.train then
      errcheck('cudnnBatchNormalizationForwardInference',
               cudnn.getHandle(), self.mode, alpha, beta,
               self.iDesc[0], input:data(),
               self.oDesc[0], self.output:data(),
               self.sDesc[0], self.weight:data(), self.bias:data(),
               self.running_mean:data(), self.running_std:data(),
               self.eps)
   else
      errcheck('cudnnBatchNormalizationForwardTraining',
               cudnn.getHandle(), self.mode, alpha, beta,
               self.iDesc[0], input:data(),
               self.oDesc[0], self.output:data(),
               self.sDesc[0], self.weight:data(), self.bias:data(),
               self.momentum,
               self.running_mean:data(), self.running_std:data(),
               self.eps,
               self.save_mean:data(), self.save_std:data())
   end
   return self.output
end
--- Shared backward implementation used by updateGradInput and backward.
-- A single cudnnBatchNormalizationBackward call overwrites gradInput
-- (alpha 1, beta 0) and accumulates into gradWeight / gradBias
-- (alpha = scale, beta 1).
-- NOTE(review): reads save_mean / save_std as last written by a
-- training-mode forward; in evaluation mode they are not refreshed.
-- @param input      the tensor given to the preceding forward
-- @param gradOutput contiguous gradient w.r.t. the output, same size as input
-- @param scale      multiplier for the parameter gradients (default 1)
-- @return self.gradInput
local function backward(self, input, gradOutput, scale)
   assert(gradOutput:isContiguous())
   assert(gradOutput:isSameSizeAs(input),
          'cudnn.BatchNormalization: gradOutput must have the same size as input')
   assert(self.save_mean and self.save_std,
          'cudnn.BatchNormalization: backward called before a forward pass')
   self:createIODescriptors(input)
   scale = scale or 1
   scaleTens:fill(scale)
   -- gradOutput is contiguous, and gradInput was sized by the same
   -- resizeAs(input) as self.output, so both are described by oDesc.
   -- iDesc carries input's own strides, which differ when input is
   -- non-contiguous.
   errcheck('cudnnBatchNormalizationBackward',
            cudnn.getHandle(), self.mode,
            one:data(), zero:data(),        -- data-gradient alpha / beta
            scaleTens:data(), one:data(),   -- parameter-gradient alpha / beta
            self.iDesc[0], input:data(),    -- x
            self.oDesc[0], gradOutput:data(),     -- dy
            self.oDesc[0], self.gradInput:data(), -- dx
            self.sDesc[0], self.weight:data(),
            self.gradWeight:data(), self.gradBias:data(),
            self.eps, self.save_mean:data(), self.save_std:data())
   return self.gradInput
end
--- Compute gradInput.
-- The shared cudnn backward call also accumulates gradWeight / gradBias
-- here, so accGradParameters has nothing left to do.
function BatchNormalization:updateGradInput(input, gradOutput, scale)
   local gradInput = backward(self, input, gradOutput, scale)
   return gradInput
end
--- Full backward pass: gradInput plus gradWeight / gradBias accumulation
-- in one cudnn call, via the shared local backward helper.
function BatchNormalization:backward(input, gradOutput, scale)
   local gradInput = backward(self, input, gradOutput, scale)
   return gradInput
end
--- Intentionally a no-op: parameter gradients are already accumulated by
-- the cudnn backward call issued from updateGradInput / backward.
function BatchNormalization:accGradParameters(input, gradOutput, scale)
   return
end
--- Drop the cached cudnn descriptors; createIODescriptors recreates them
-- on the next call because iDesc / oDesc are then nil.
function BatchNormalization:clearDesc()
   self.iDesc, self.oDesc, self.sDesc = nil, nil, nil
end
--- Serialise the module.
-- Clears the descriptors first (presumably because they are FFI handles
-- that cannot be written), then writes a shallow copy of the fields.
-- @param f torch File being written to
function BatchNormalization:write(f)
   self:clearDesc()
   local fields = {}
   for key, value in pairs(self) do
      fields[key] = value
   end
   f:writeObject(fields)
end
--- Release transient state: cudnn descriptors and the saved batch
-- statistics, then defer to the parent class's clearState.
function BatchNormalization:clearState()
   self:clearDesc()
   local clear = nn.utils.clear
   clear(self, 'save_mean', 'save_std')
   return parent.clearState(self)
end