Merge pull request torch#353 from adamlerer/class_nll_criterion
Improve ClassNLLCriterion
soumith committed Aug 25, 2015
2 parents b87b263 + ef2e3d2 commit 609fb78
Showing 3 changed files with 223 additions and 101 deletions.
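In short: the constructor gains an optional sizeAverage argument, the hand-written Lua loops are replaced by a C backend (CUDA tensors go through the same entry point), and size averaging now normalizes by the total target weight rather than the batch size. A minimal sketch of the new constructor signature (the weight values are illustrative, not from the diff):

local crit     = nn.ClassNLLCriterion()                       -- defaults: no weights, sizeAverage = true
local weighted = nn.ClassNLLCriterion(torch.Tensor{1, 2, 1})  -- optional per-class weights
local summed   = nn.ClassNLLCriterion(nil, false)             -- new: disable size averaging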
156 changes: 55 additions & 101 deletions ClassNLLCriterion.lua
@@ -1,16 +1,28 @@
-local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion')
+local ClassNLLCriterion, parent = torch.class(
+   'nn.ClassNLLCriterion',
+   'nn.Criterion'
+)
 
-function ClassNLLCriterion:__init(weights)
-   parent.__init(self)
-   self.sizeAverage = true
-   self.outputTensor = torch.Tensor(1)
-   if weights then
-      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
-      self.weights = weights
-   end
+function ClassNLLCriterion:__init(weights, sizeAverage)
+   parent.__init(self)
+   if sizeAverage ~= nil then
+      self.sizeAverage = sizeAverage
+   else
+      self.sizeAverage = true
+   end
+   if weights then
+      assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+      self.weights = weights
+   end
+
+   self.output_tensor = torch.zeros(1)
+   self.total_weight_tensor = torch.zeros(1)
+   self.target = torch.zeros(1):long()
 end
 
+
+
 function ClassNLLCriterion:__len()
    if (self.weights) then
       return #self.weights
@@ -21,101 +33,43 @@ end
 
 
 function ClassNLLCriterion:updateOutput(input, target)
-   if input:type() == 'torch.CudaTensor' then
-      if self.weights == nil then
-         -- The CUDA implementation requires self.weights be non-nil
-         self.weights = torch.CudaTensor()
-      end
-      assert(self.weights:dim() == 0 or self.weights:dim() == 1,
-             'weights must be 1D or empty')
-      -- The cuda code wont check weight size, so we must do it here.
-      if self.weights:dim() == 1 then
-         if input:dim() == 1 then
-            assert(self.weights:size(1) == input:size(1),
-                   'Wrong number of weights')
-         else
-            assert(self.weights:size(1) == input:size(2),
-                   'Wrong number of weights')
-         end
-      end
-      if input:dim() == 1 then
-         self._target = self._target or input.new(1)
-         if type(target) == 'number' then
-            self._target[1] = target
-         else
-            self._target:copy(target)
-         end
-         input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
-      else
-         input.nn.ClassNLLCriterion_updateOutput(self, input, target)
-      end
-      self.output = self.outputTensor[1]
-      return self.output
-   end
+   if type(target) == 'number' then
+      self.target[1] = target
+   elseif target:type() == 'torch.CudaTensor' then
+      self.target = target
+   else
+      self.target = target:long()
+   end
 
-   if input:dim() == 1 then
-      if torch.isTensor(target) then target = target[1] end
-      self.output = -input[target]
-      if self.weights then
-         self.output = self.output*self.weights[target]
-      end
-   elseif input:dim() == 2 then
-      local output = 0
-      for i=1,target:size(1) do
-         if self.weights then
-            output = output - input[i][target[i]]*self.weights[target[i]]
-         else
-            output = output - input[i][target[i]]
-         end
-      end
-      if self.sizeAverage then
-         output = output / target:size(1)
-      end
-      self.output = output
-   else
-      error('matrix or vector expected')
-   end
-   return self.output
+   input.nn.ClassNLLCriterion_updateOutput(
+      input,
+      self.target,
+      self.weights,
+      self.sizeAverage,
+      self.output_tensor,
+      self.total_weight_tensor
+   )
+   self.output = self.output_tensor[1]
+   return self.output, self.total_weight_tensor[1]
 end
 
 function ClassNLLCriterion:updateGradInput(input, target)
-   self.gradInput:resizeAs(input)
-   self.gradInput:zero()
+   if type(target) == 'number' then
+      self.target[1] = target
+   elseif target:type() == 'torch.CudaTensor' then
+      self.target = target
+   else
+      self.target = target:long()
+   end
 
-   if input:type() == 'torch.CudaTensor' then
-      -- Note: we'll assume that updateOutput() has been called and self.weights
-      -- is non-nil.
-      if input:dim() == 1 then
-         self._target = self._target or input.new(1)
-         if type(target) == 'number' then
-            self._target[1] = target
-         else
-            self._target:copy(target)
-         end
-         input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
-      else
-         input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
-      end
-      return self.gradInput
-   end
 
-   if input:dim() == 1 then
-      if torch.isTensor(target) then target = target[1] end
-      self.gradInput[target] = -1
-      if self.weights then
-         self.gradInput[target] = self.gradInput[target]*self.weights[target]
-      end
-   else
-      local z = -1
-      if self.sizeAverage then
-         z = z / target:size(1)
-      end
-      for i=1,target:size(1) do
-         self.gradInput[i][target[i]] = z
-         if self.weights then
-            self.gradInput[i][target[i]] = self.gradInput[i][target[i]]*self.weights[target[i]]
-         end
-      end
-   end
-   return self.gradInput
+   self.gradInput:resizeAs(input):zero()
+   input.nn.ClassNLLCriterion_updateGradInput(
+      input,
+      self.target,
+      self.weights,
+      self.sizeAverage,
+      self.total_weight_tensor,
+      self.gradInput
+   )
+   return self.gradInput
 end
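A quick forward/backward sketch against the rewritten front end, assuming an nn build that includes this commit (batch size and values are illustrative). Note that updateOutput now returns the accumulated target weight as a second value alongside the loss:

local crit = nn.ClassNLLCriterion(torch.Tensor{1, 2, 1})     -- 3 classes, illustrative weights
local logprobs = nn.LogSoftMax():forward(torch.randn(4, 3))  -- batch of 4 log-probability rows
local targets = torch.LongTensor{1, 3, 2, 1}                 -- 1-based class indices

local loss = crit:forward(logprobs, targets)       -- scalar, averaged by total weight
local gradInput = crit:backward(logprobs, targets) -- 4x3 gradient w.r.t. logprobs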
163 changes: 163 additions & 0 deletions generic/ClassNLLCriterion.c
@@ -0,0 +1,163 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/ClassNLLCriterion.c"
#else


static int nn_(ClassNLLCriterion_updateOutput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
  THLongTensor *target = luaT_checkudata(L, 2, "torch.LongTensor");
  THTensor *weights = NULL;
  if (!lua_isnil(L, 3)) {
    weights = luaT_checkudata(L, 3, torch_Tensor);
  }
  int n_dims = THTensor_(nDimension)(input);
  int n_classes = THTensor_(size)(input, n_dims - 1);

  int sizeAverage = lua_toboolean(L, 4);
  THTensor *output = luaT_checkudata(L, 5, torch_Tensor);
  THTensor *total_weight = luaT_checkudata(L, 6, torch_Tensor);

  if (THLongTensor_nDimension(target) > 1) {
    THError("multi-target not supported");
  }
  if (THTensor_(nDimension)(input) > 2) {
    THError("input tensor should be 1D or 2D");
  }

  input = THTensor_(newContiguous)(input);
  target = THLongTensor_newContiguous(target);
  weights = weights ? THTensor_(newContiguous)(weights) : NULL;

  real *input_data = THTensor_(data)(input);
  long *target_data = THLongTensor_data(target);
  real *weights_data = weights ? THTensor_(data)(weights) : NULL;
  real *output_data = THTensor_(data)(output);
  real *total_weight_data = THTensor_(data)(total_weight);

  output_data[0] = total_weight_data[0] = 0.0;

  if (THTensor_(nDimension)(input) == 1) {
    int cur_target = target_data[0] - 1;
    THAssert(cur_target >= 0 && cur_target < n_classes);
    total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f;
    output_data[0] = -input_data[cur_target] * total_weight_data[0];
  } else if (THTensor_(nDimension)(input) == 2) {
    int batch_size = THTensor_(size)(input, 0);
    int n_target = THTensor_(size)(input, 1);

    int i;
    for (i = 0; i < batch_size; i++) {
      int cur_target = target_data[i] - 1;
      THAssert(cur_target >= 0 && cur_target < n_classes);

      real cur_weight = weights ? weights_data[cur_target] : 1.0f;
      total_weight_data[0] += cur_weight;
      output_data[0] -= input_data[i * n_target + cur_target] * cur_weight;
    }
  }

  if (sizeAverage && total_weight_data[0]) {
    output_data[0] /= total_weight_data[0];
  }

  if (weights) {
    THTensor_(free)(weights);
  }
  THTensor_(free)(input);
  THLongTensor_free(target);

  return 0;
}

static int nn_(ClassNLLCriterion_updateGradInput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
  THLongTensor *target = luaT_checkudata(L, 2, "torch.LongTensor");
  THTensor *weights = NULL;
  if (!lua_isnil(L, 3)) {
    weights = luaT_checkudata(L, 3, torch_Tensor);
  }

  int n_dims = THTensor_(nDimension)(input);
  int n_classes = THTensor_(size)(input, n_dims - 1);

  int sizeAverage = lua_toboolean(L, 4);
  THTensor *total_weight = luaT_checkudata(L, 5, torch_Tensor);
  THTensor *gradInput = luaT_checkudata(L, 6, torch_Tensor);
  luaL_argcheck(
    L,
    THTensor_(isContiguous)(gradInput),
    6,
    "gradInput must be contiguous"
  );

  real *total_weight_data = THTensor_(data)(total_weight);

  if (!(*total_weight_data > 0)) {
    return 0;
  }

  if (THLongTensor_nDimension(target) > 1) {
    THError("multi-target not supported");
  }

  if (THTensor_(nDimension)(input) > 2) {
    THError("input tensor should be 1D or 2D");
  }

  target = THLongTensor_newContiguous(target);
  weights = weights ? THTensor_(newContiguous)(weights) : NULL;

  long *target_data = THLongTensor_data(target);
  real *weights_data = weights ? THTensor_(data)(weights) : NULL;
  real *gradInput_data = THTensor_(data)(gradInput);

  if (THTensor_(nDimension)(input) == 1) {
    int cur_target = target_data[0] - 1;
    THAssert(cur_target >= 0 && cur_target < n_classes);

    gradInput_data[cur_target] =
      (!sizeAverage && weights) ? -weights_data[cur_target] : -1;

  } else if (THTensor_(nDimension)(input) == 2) {
    int batch_size = THTensor_(size)(input, 0);
    int n_target = THTensor_(size)(input, 1);

    int i;
    for (i = 0; i < batch_size; i++) {
      int cur_target = target_data[i] - 1;

      THAssert(cur_target >= 0 && cur_target < n_classes);

      gradInput_data[i * n_target + cur_target] =
        -(weights ? weights_data[cur_target] : 1.0f);

      if (sizeAverage && *total_weight_data) {
        gradInput_data[i * n_target + cur_target] /= *total_weight_data;
      }
    }
  }

  THLongTensor_free(target);
  if (weights) {
    THTensor_(free)(weights);
  }

  return 0;
}

static const struct luaL_Reg nn_(ClassNLLCriterion__) [] = {
  {"ClassNLLCriterion_updateOutput", nn_(ClassNLLCriterion_updateOutput)},
  {"ClassNLLCriterion_updateGradInput", nn_(ClassNLLCriterion_updateGradInput)},
  {NULL, NULL}
};

static void nn_(ClassNLLCriterion_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nn_(ClassNLLCriterion__), "nn");
  lua_pop(L, 1);
}

#endif
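For reference, the quantity the two functions above compute, written out: with x the input log-probabilities, t_i the 1-based target class of sample i, w the optional per-class weights (taken as 1 when absent), and W = sum_i w_{t_i} when sizeAverage is set (W = 1 otherwise):

$$\mathrm{loss}(x, t) = -\frac{1}{W} \sum_i w_{t_i}\, x_{i, t_i}, \qquad \frac{\partial\, \mathrm{loss}}{\partial x_{i,c}} = \begin{cases} -\, w_{t_i} / W & \text{if } c = t_i \\ 0 & \text{otherwise} \end{cases}$$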
5 changes: 5 additions & 0 deletions init.c
@@ -47,6 +47,9 @@
 #include "generic/SoftMax.c"
 #include "THGenerateFloatTypes.h"
 
+#include "generic/ClassNLLCriterion.c"
+#include "THGenerateFloatTypes.h"
+
 #include "generic/MSECriterion.c"
 #include "THGenerateFloatTypes.h"
 
@@ -134,6 +137,7 @@ int luaopen_libnn(lua_State *L)
   nn_FloatSquare_init(L);
   nn_FloatHardTanh_init(L);
   nn_FloatLogSoftMax_init(L);
+  nn_FloatClassNLLCriterion_init(L);
   nn_FloatMSECriterion_init(L);
   nn_FloatMarginCriterion_init(L);
   nn_FloatAbsCriterion_init(L);
@@ -174,6 +178,7 @@
   nn_DoubleSquare_init(L);
   nn_DoubleHardTanh_init(L);
   nn_DoubleLogSoftMax_init(L);
+  nn_DoubleClassNLLCriterion_init(L);
   nn_DoubleMSECriterion_init(L);
   nn_DoubleMarginCriterion_init(L);
   nn_DoubleAbsCriterion_init(L);
