forked from torch/nn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
View.lua
87 lines (74 loc) · 2.18 KB
/
View.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
-- Register the class 'nn.View' with 'nn.Module' as its parent class.
-- `View` receives the methods defined below; `parent` is used by the
-- constructor to chain to the parent's __init.
local View, parent = torch.class('nn.View', 'nn.Module')
--- Construct a View from the requested output sizes.
-- The sizes may be given either as a single torch.LongStorage (stored as-is)
-- or as a list of numbers (packed into a new torch.LongStorage). At most one
-- entry may be -1; every other entry must be non-negative.
--
-- Fields set:
--   self.size         the requested sizes
--   self.numElements  product of all non-negative entries of self.size
--   self.output, self.gradInput, self.numInputDims  cleared to nil
-- @param ... a torch.LongStorage, or the sizes as separate numbers
function View:__init(...)
   parent.__init(self)

   -- Keep the short-circuit: torch.typename is consulted only for a single argument.
   local nargs = select('#', ...)
   if nargs == 1 and torch.typename((...)) == 'torch.LongStorage' then
      self.size = (...)
   else
      self.size = torch.LongStorage({...})
   end

   self.numElements = 1
   local sawUnspecified = false
   local dims = self.size
   for idx = 1, #dims do
      local extent = dims[idx]
      if extent < 0 then
         -- The only negative value accepted is -1, and only once.
         assert(extent == -1, 'size should be positive or -1')
         assert(not sawUnspecified, 'only one dimension can be at -1')
         sawUnspecified = true
      else
         self.numElements = self.numElements * extent
      end
   end

   self.output = nil
   self.gradInput = nil
   self.numInputDims = nil
end
--- Store the number of input dimensions that belong to a single sample.
-- The value is kept in self.numInputDims and handed to the batch-size
-- computation on every forward pass.
-- @param numInputDims number of non-batch dimensions of the input
-- @return self, so the call can be chained onto the constructor
function View:setNumInputDims(numInputDims)
   self.numInputDims = numInputDims
   return self
end
--- Compute the leading batch dimension to prepend when viewing `input`.
--
-- The trailing `numInputDims` dimensions of `input` (all of them when
-- `numInputDims` is nil) are taken as one sample. Their element count must be
-- a multiple of `numElements`; the quotient is the batch size unless `size`
-- contains a -1, in which case that entry absorbs the quotient instead. Any
-- leading dimensions in front of the sample dimensions are always multiplied
-- into the batch size.
--
-- @param input        object providing :nDimension() and :size()
-- @param size         requested sizes; indexable, providing :size() and :totable()
-- @param numInputDims optional number of trailing dimensions forming one sample
-- @param numElements  product of the non-negative entries of `size`
-- @return the batch size, or nothing when no batch dimension should be added
-- Raises an error when the input cannot be viewed with the requested sizes.
local function batchsize(input, size, numInputDims, numElements)
   local ind = input:nDimension()
   local isz = input:size()
   local maxdim = numInputDims or ind

   -- Without this check the loop below would index isz at 0 or below and
   -- fail with an unrelated out-of-bounds/arithmetic error.
   if ind - maxdim < 0 then
      error(string.format(
         'input view (%s) has fewer dimensions than numInputDims (%s)',
         table.concat(isz:totable(), 'x'),
         tostring(maxdim)))
   end

   -- Number of elements in one sample: product of the trailing maxdim dims.
   local ine = 1
   for i = ind, ind - maxdim + 1, -1 do
      ine = ine * isz[i]
   end

   -- numElements == 0 is rejected explicitly instead of relying on what the
   -- runtime does for a modulo by zero.
   if numElements == 0 or ine % numElements ~= 0 then
      error(string.format(
         'input view (%s) and desired view (%s) do not match',
         table.concat(isz:totable(), 'x'),
         table.concat(size:totable(), 'x')))
   end

   -- the remainder is either the batch...
   local bsz = ine / numElements

   -- ... or the missing size dim
   for i = 1, size:size() do
      if size[i] == -1 then
         bsz = 1
         break
      end
   end

   -- dimensions in front of the sample dimensions are definitely batch
   for i = ind - maxdim, 1, -1 do
      bsz = bsz * isz[i]
   end

   -- special case: a batch of one is dropped unless numInputDims was given
   -- and the input has more dimensions than that
   if bsz == 1 and (not numInputDims or ind <= numInputDims) then
      return
   end

   return bsz
end
--- Forward pass: expose `input` under the configured sizes.
-- When the batch-size computation yields a value, it is prepended to the
-- configured sizes; otherwise the input is viewed with the configured sizes
-- alone. The result is stored in self.output.
-- @param input the input to re-view
-- @return self.output
function View:updateOutput(input)
   local batch = batchsize(input, self.size, self.numInputDims, self.numElements)

   if not batch then
      self.output = input:view(self.size)
      return self.output
   end

   local dims = self.size:totable()
   self.output = input:view(batch, table.unpack(dims))
   return self.output
end
--- Backward pass: view `gradOutput` with the sizes of `input`.
-- The result is stored in self.gradInput.
-- @param input      the input that was given to the forward pass
-- @param gradOutput gradient with respect to the module's output
-- @return self.gradInput
function View:updateGradInput(input, gradOutput)
   local inputSizes = input:size()
   self.gradInput = gradOutput:view(inputSizes)
   return self.gradInput
end