diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index bc2a2e98d..9b8d0aa5e 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -1,16 +1,28 @@
-local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion')
+local ClassNLLCriterion, parent = torch.class(
+   'nn.ClassNLLCriterion',
+   'nn.Criterion'
+)
 
-function ClassNLLCriterion:__init(weights)
-   parent.__init(self)
-   self.sizeAverage = true
-   self.outputTensor = torch.Tensor(1)
-   if weights then
-      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
-      self.weights = weights
-   end
+function ClassNLLCriterion:__init(weights, sizeAverage)
+   parent.__init(self)
+   if sizeAverage ~= nil then
+      self.sizeAverage = sizeAverage
+   else
+      self.sizeAverage = true
+   end
+   if weights then
+      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+      self.weights = weights
+   end
+
+   self.output_tensor = torch.zeros(1)
+   self.total_weight_tensor = torch.zeros(1)
+   self.target = torch.zeros(1):long()
 end
+
+
 function ClassNLLCriterion:__len()
    if (self.weights) then
      return #self.weights
@@ -21,101 +33,43 @@ end
 
 function ClassNLLCriterion:updateOutput(input, target)
-   if input:type() == 'torch.CudaTensor' then
-      if self.weights == nil then
-         -- The CUDA implementation requires self.weights be non-nil
-         self.weights = torch.CudaTensor()
-      end
-      assert(self.weights:dim() == 0 or self.weights:dim() == 1,
-             'weights must be 1D or empty')
-      -- The cuda code wont check weight size, so we must do it here.
-      if self.weights:dim() == 1 then
-         if input:dim() == 1 then
-            assert(self.weights:size(1) == input:size(1),
-                   'Wrong number of weights')
-         else
-            assert(self.weights:size(1) == input:size(2),
-                   'Wrong number of weights')
-         end
-      end
-      if input:dim() == 1 then
-         self._target = self._target or input.new(1)
-         if type(target) == 'number' then
-            self._target[1] = target
-         else
-            self._target:copy(target)
-         end
-         input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
-      else
-         input.nn.ClassNLLCriterion_updateOutput(self, input, target)
-      end
-      self.output = self.outputTensor[1]
-      return self.output
-   end
+   if type(target) == 'number' then
+      self.target[1] = target
+   elseif target:type() == 'torch.CudaTensor' then
+      self.target = target
+   else
+      self.target = target:long()
+   end
 
-   if input:dim() == 1 then
-      if torch.isTensor(target) then target = target[1] end
-      self.output = -input[target]
-      if self.weights then
-         self.output = self.output*self.weights[target]
-      end
-   elseif input:dim() == 2 then
-      local output = 0
-      for i=1,target:size(1) do
-         if self.weights then
-            output = output - input[i][target[i]]*self.weights[target[i]]
-         else
-            output = output - input[i][target[i]]
-         end
-      end
-      if self.sizeAverage then
-         output = output / target:size(1)
-      end
-      self.output = output
-   else
-      error('matrix or vector expected')
-   end
-   return self.output
+   input.nn.ClassNLLCriterion_updateOutput(
+      input,
+      self.target,
+      self.weights,
+      self.sizeAverage,
+      self.output_tensor,
+      self.total_weight_tensor
+   )
+   self.output = self.output_tensor[1]
+   return self.output, self.total_weight_tensor[1]
 end
 
 function ClassNLLCriterion:updateGradInput(input, target)
-   self.gradInput:resizeAs(input)
-   self.gradInput:zero()
+   if type(target) == 'number' then
+      self.target[1] = target
+   elseif target:type() == 'torch.CudaTensor' then
+      self.target = target
+   else
+      self.target = target:long()
+   end
 
-   if input:type() == 'torch.CudaTensor' then
-      -- Note: we'll assume that updateOutput() has been called and self.weights
-      -- is non-nil.
-      if input:dim() == 1 then
-         self._target = self._target or input.new(1)
-         if type(target) == 'number' then
-            self._target[1] = target
-         else
-            self._target:copy(target)
-         end
-         input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
-      else
-         input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
-      end
-      return self.gradInput
-   end
-
-   if input:dim() == 1 then
-      if torch.isTensor(target) then target = target[1] end
-      self.gradInput[target] = -1
-      if self.weights then
-         self.gradInput[target] = self.gradInput[target]*self.weights[target]
-      end
-   else
-      local z = -1
-      if self.sizeAverage then
-         z = z / target:size(1)
-      end
-      for i=1,target:size(1) do
-         self.gradInput[i][target[i]] = z
-         if self.weights then
-            self.gradInput[i][target[i]] = self.gradInput[i][target[i]]*self.weights[target[i]]
-         end
-      end
-   end
-   return self.gradInput
+   self.gradInput:resizeAs(input):zero()
+   input.nn.ClassNLLCriterion_updateGradInput(
+      input,
+      self.target,
+      self.weights,
+      self.sizeAverage,
+      self.total_weight_tensor,
+      self.gradInput
+   )
+   return self.gradInput
 end
diff --git a/generic/ClassNLLCriterion.c b/generic/ClassNLLCriterion.c
new file mode 100644
index 000000000..d8efef76f
--- /dev/null
+++ b/generic/ClassNLLCriterion.c
@@ -0,0 +1,163 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/ClassNLLCriterion.c"
+#else
+
+
+static int nn_(ClassNLLCriterion_updateOutput)(lua_State *L)
+{
+  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
+  THLongTensor *target = luaT_checkudata(L, 2, "torch.LongTensor");
+  THTensor *weights = NULL;
+  if (!lua_isnil(L, 3)) {
+    weights = luaT_checkudata(L, 3, torch_Tensor);
+  }
+  int n_dims = THTensor_(nDimension)(input);
+  int n_classes = THTensor_(size)(input, n_dims - 1);
+
+  int sizeAverage = lua_toboolean(L, 4);
+  THTensor *output = luaT_checkudata(L, 5, torch_Tensor);
+  THTensor *total_weight = luaT_checkudata(L, 6, torch_Tensor);
+
+  if (THLongTensor_nDimension(target) > 1) {
+    THError("multi-target not supported");
+  }
+  if (THTensor_(nDimension)(input) > 2) {
+    THError("input tensor should be 1D or 2D");
+  }
+
+  input = THTensor_(newContiguous)(input);
+  target = THLongTensor_newContiguous(target);
+  weights = weights ? THTensor_(newContiguous)(weights) : NULL;
+
+  real *input_data = THTensor_(data)(input);
+  long *target_data = THLongTensor_data(target);
+  real *weights_data = weights ? THTensor_(data)(weights) : NULL;
+  real *output_data = THTensor_(data)(output);
+  real *total_weight_data = THTensor_(data)(total_weight);
+
+  output_data[0] = total_weight_data[0] = 0.0;
+
+  if (THTensor_(nDimension)(input) == 1) {
+    int cur_target = target_data[0] - 1;
+    THAssert(cur_target >= 0 && cur_target < n_classes);
+    total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
+    output_data[0] = -input_data[cur_target] * total_weight_data[0];
+  } else if (THTensor_(nDimension)(input) == 2) {
+    int batch_size = THTensor_(size)(input, 0);
+    int n_target = THTensor_(size)(input, 1);
+
+    int i;
+    for (i = 0; i < batch_size; i++) {
+      int cur_target = target_data[i] - 1;
+      THAssert(cur_target >= 0 && cur_target < n_classes);
+
+      real cur_weight = weights ? weights_data[cur_target] : 1.0f;
+      total_weight_data[0] += cur_weight;
+      output_data[0] -= input_data[i * n_target + cur_target] * cur_weight;
+    }
+  }
+
+  if (sizeAverage && total_weight_data[0]) {
+    output_data[0] /= total_weight_data[0];
+  }
+
+  if (weights) {
+    THTensor_(free)(weights);
+  }
+  THTensor_(free)(input);
+  THLongTensor_free(target);
+
+  return 0;
+}
+
+static int nn_(ClassNLLCriterion_updateGradInput)(lua_State *L)
+{
+  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
+  THLongTensor *target = luaT_checkudata(L, 2, "torch.LongTensor");
+  THTensor *weights = NULL;
+  if (!lua_isnil(L, 3)) {
+    weights = luaT_checkudata(L, 3, torch_Tensor);
+  }
+
+  int n_dims = THTensor_(nDimension)(input);
+  int n_classes = THTensor_(size)(input, n_dims - 1);
+
+  int sizeAverage = lua_toboolean(L, 4);
+  THTensor *total_weight = luaT_checkudata(L, 5, torch_Tensor);
+  THTensor *gradInput = luaT_checkudata(L, 6, torch_Tensor);
+  luaL_argcheck(
+    L,
+    THTensor_(isContiguous)(gradInput),
+    6,
+    "gradInput must be contiguous"
+  );
+
+  real* total_weight_data = THTensor_(data)(total_weight);
+
+  if (!(*total_weight_data > 0)) {
+    return 0;
+  }
+
+  if (THLongTensor_nDimension(target) > 1) {
+    THError("multi-target not supported");
+  }
+
+  if (THTensor_(nDimension)(input) > 2) {
+    THError("input tensor should be 1D or 2D");
+  }
+
+  target = THLongTensor_newContiguous(target);
+  weights = weights ? THTensor_(newContiguous)(weights) : NULL;
+
+  long *target_data = THLongTensor_data(target);
+  real *weights_data = weights ? THTensor_(data)(weights) : NULL;
+  real *gradInput_data = THTensor_(data)(gradInput);
+
+  if (THTensor_(nDimension)(input) == 1) {
+    int cur_target = target_data[0] - 1;
+    THAssert(cur_target >= 0 && cur_target < n_classes);
+
+    gradInput_data[cur_target] =
+      (!sizeAverage && weights) ? -weights_data[cur_target] : -1;
+
+  } else if (THTensor_(nDimension)(input) == 2) {
+    int batch_size = THTensor_(size)(input, 0);
+    int n_target = THTensor_(size)(input, 1);
+
+    int i;
+    for(i = 0; i < batch_size; i++){
+      int cur_target = target_data[i] - 1;
+
+      THAssert(cur_target >= 0 && cur_target < n_classes);
+
+      gradInput_data[i * n_target + cur_target] =
+        -(weights ? weights_data[cur_target] : 1.0f);
+
+      if (sizeAverage && *total_weight_data) {
+        gradInput_data[i * n_target + cur_target] /= *total_weight_data;
+      }
+    }
+  }
+
+  THLongTensor_free(target);
+  if (weights) {
+    THTensor_(free)(weights);
+  }
+
+  return 0;
+}
+
+static const struct luaL_Reg nn_(ClassNLLCriterion__) [] = {
+  {"ClassNLLCriterion_updateOutput", nn_(ClassNLLCriterion_updateOutput)},
+  {"ClassNLLCriterion_updateGradInput", nn_(ClassNLLCriterion_updateGradInput)},
+  {NULL, NULL}
+};
+
+static void nn_(ClassNLLCriterion_init)(lua_State *L)
+{
+  luaT_pushmetatable(L, torch_Tensor);
+  luaT_registeratname(L, nn_(ClassNLLCriterion__), "nn");
+  lua_pop(L,1);
+}
+
+#endif
diff --git a/init.c b/init.c
index 7cdae6967..0fc72085d 100644
--- a/init.c
+++ b/init.c
@@ -47,6 +47,9 @@
 #include "generic/SoftMax.c"
 #include "THGenerateFloatTypes.h"
 
+#include "generic/ClassNLLCriterion.c"
+#include "THGenerateFloatTypes.h"
+
 #include "generic/MSECriterion.c"
 #include "THGenerateFloatTypes.h"
 
@@ -134,6 +137,7 @@ int luaopen_libnn(lua_State *L)
   nn_FloatSquare_init(L);
   nn_FloatHardTanh_init(L);
   nn_FloatLogSoftMax_init(L);
+  nn_FloatClassNLLCriterion_init(L);
   nn_FloatMSECriterion_init(L);
   nn_FloatMarginCriterion_init(L);
   nn_FloatAbsCriterion_init(L);
@@ -174,6 +178,7 @@ int luaopen_libnn(lua_State *L)
   nn_DoubleSquare_init(L);
   nn_DoubleHardTanh_init(L);
   nn_DoubleLogSoftMax_init(L);
+  nn_DoubleClassNLLCriterion_init(L);
   nn_DoubleMSECriterion_init(L);
   nn_DoubleMarginCriterion_init(L);
   nn_DoubleAbsCriterion_init(L);
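
Not part of the patch: a minimal Lua usage sketch of the reworked criterion, for review context. The class weights, batch size, and targets below are illustrative values only.

    require 'nn'

    -- ClassNLLCriterion expects log-probabilities, e.g. the output of nn.LogSoftMax.
    local weights = torch.Tensor{0.5, 1.0, 2.0}                -- optional per-class rescaling
    local crit = nn.ClassNLLCriterion(weights)                 -- sizeAverage defaults to true

    local input = nn.LogSoftMax():forward(torch.randn(4, 3))   -- batch of 4 samples, 3 classes
    local target = torch.LongTensor{1, 3, 2, 3}                -- 1-based class indices

    local loss = crit:forward(input, target)                   -- dispatches to the C updateOutput
    local gradInput = crit:backward(input, target)             -- dispatches to the C updateGradInput
    print(loss, gradInput:size())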