Merge pull request torch#345 from kosklain/master
Fix contiguous gradOutput bug in nn.Mean
soumith committed Aug 11, 2015
2 parents d72e7a6 + a1762b9 commit 0f5c1cc
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions Mean.lua
@@ -4,6 +4,7 @@ function Mean:__init(dimension)
    parent.__init(self)
    dimension = dimension or 1
    self.dimension = dimension
+   self._gradInput = torch.Tensor()
 end
 
 function Mean:updateOutput(input)
@@ -15,20 +16,13 @@ function Mean:updateOutput(input)
 end
 
 function Mean:updateGradInput(input, gradOutput)
-   local size = gradOutput:size():totable()
-   local stride = gradOutput:stride():totable()
+   self._gradInput:resizeAs(gradOutput):copy(gradOutput)
+   self._gradInput:mul(1/input:size(self.dimension))
 
    if input:nDimension() > 1 then
-      table.insert(size, self.dimension, input:size(self.dimension))
-      table.insert(stride, self.dimension, 0)
-   else
-      size[1] = input:size(1)
-      stride[1] = 0
+      self._gradInput = nn.utils.addSingletonDimension(self._gradInput,
+                                                       self.dimension)
    end
-
-   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
-   self.gradInput:mul(1/input:size(self.dimension))
-   self.gradInput:resize(torch.LongStorage(size), torch.LongStorage(stride))
-
+   self.gradInput = self._gradInput:expandAs(input)
    return self.gradInput
 end
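
For context, a minimal sketch (not part of the commit) of the situation the fix targets: a non-contiguous gradOutput reaching nn.Mean:updateGradInput. The 4x5 input, the Mean(1) module, and the column-select used to build a strided gradOutput below are illustrative assumptions, not taken from the pull request.

-- Minimal sketch (illustrative): feed nn.Mean a non-contiguous gradOutput,
-- the case the old size/stride trick mishandled.
require 'nn'

local m = nn.Mean(1)                   -- mean over dimension 1
local input = torch.randn(4, 5)
m:forward(input)                       -- output is a 1D tensor of size 5

-- A 1D gradOutput of size 5 that is a non-contiguous view (stride 2),
-- obtained by selecting one column of a 5x2 tensor.
local gradOutput = torch.randn(5, 2):select(2, 1)
assert(not gradOutput:isContiguous())

-- The old code rebuilt gradInput from gradOutput's original strides after
-- copying it into a freshly resized (contiguous) tensor, so a strided
-- gradOutput could be read at the wrong offsets. The new code copies into
-- the contiguous buffer self._gradInput and expands it over the input.
local gradInput = m:backward(input, gradOutput)
print(gradInput:size())                -- 4x5; every row equals gradOutput / 4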
