diff --git a/Linear.lua b/Linear.lua index 31980bc06..7f6810d8d 100644 --- a/Linear.lua +++ b/Linear.lua @@ -39,7 +39,11 @@ function Linear:updateOutput(input) self.output:addmv(1, self.weight, input) elseif input:dim() == 2 then local nframe = input:size(1) + local nElement = self.output:nElement() self.output:resize(nframe, self.bias:size(1)) + if self.output:nElement() ~= nElement then + self.output:zero() + end if not self.addBuffer or self.addBuffer:nElement() ~= nframe then self.addBuffer = input.new(nframe):fill(1) end