diff --git a/explainer/ebp/functions.py b/explainer/ebp/functions.py
index 379f942..4336887 100644
--- a/explainer/ebp/functions.py
+++ b/explainer/ebp/functions.py
@@ -6,7 +6,7 @@ class EBLinear(Function):

     @staticmethod
-    def forward(ctx, inp, weight, bias=None):
+    def forward(ctx, inp, weight, bias=None):
         ctx.save_for_backward(inp, weight, bias)
         output = inp.matmul(weight.t())
         if bias is not None:
@@ -30,16 +30,23 @@ def backward(ctx, grad_output):

 def _output_size(inp, weight, pad, dilation, stride):
-    pad = pad[0]
-    dilation = dilation[0]
-    stride = stride[0]
+
+    # if pad/dilation/stride are given as a single value, expand to one value per spatial dim
+    if len(pad) == 1:
+        pad = [pad[0] for _ in range(inp.dim() - 2)]
+    if len(dilation) == 1:
+        dilation = [dilation[0] for _ in range(inp.dim() - 2)]
+    if len(stride) == 1:
+        stride = [stride[0] for _ in range(inp.dim() - 2)]
     channels = weight.size(0)
+
     output_size = (inp.size(0), channels)
     for d in range(inp.dim() - 2):
         in_size = inp.size(d + 2)
-        kernel = dilation * (weight.size(d + 2) - 1) + 1
-        output_size += ((in_size + (2 * pad) - kernel) // stride + 1,)
+        kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+        output_size += ((in_size + (2 * pad[d]) - kernel) // stride[d] + 1,)
+
     if not all(map(lambda s: s > 0, output_size)):
         raise ValueError("convolution inp is too small (output would be {})".format(
             'x'.join(map(str, output_size))))
@@ -57,17 +64,22 @@ def forward(ctx, inp, weight, bias, stride, padding, dilation, groups):
         ctx.dilation = _pair(dilation)
         ctx.groups = groups
         kH, kW = weight.size(2), weight.size(3)
-
+        output_size = _output_size(inp, weight, padding, dilation, stride)
+        output = inp.new(*output_size)
+        columns = inp.new(*output_size)
         ones = inp.new(*output_size)
         backend = type2backend[inp.type()]
         f = getattr(backend, 'SpatialConvolutionMM_updateOutput')
+
+        # argument order as stated in
+        # https://github.com/torch/nn/blob/master/lib/THNN/generic/SpatialConvolutionMM.c
         f(backend.library_state, inp, output, weight, bias, columns, ones,
-          kH, kW, ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1])
-
+          kW, kH, ctx.stride[1], ctx.stride[0], ctx.padding[1], ctx.padding[0])
+
         return output

     @staticmethod
@@ -75,6 +87,7 @@ def backward(ctx, grad_output):
         inp, weight, bias = ctx.saved_tensors
         stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
         output_size = _output_size(inp, weight, padding, dilation, stride)
+
         kH, kW = weight.size(2), weight.size(3)

         wplus = weight.clone().clamp(min=0)
@@ -85,7 +98,8 @@ def backward(ctx, grad_output):
         backend = type2backend[inp.type()]
         f = getattr(backend, 'SpatialConvolutionMM_updateOutput')
         f(backend.library_state, inp, new_output, wplus, None, columns, ones,
-          kH, kW, ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1])
+          kW, kH, ctx.stride[1], ctx.stride[0], ctx.padding[1], ctx.padding[0])
+

         normalized_grad_output = grad_output.data / (new_output + 1e-10)
         normalized_grad_output = normalized_grad_output * (new_output > 0).float()
@@ -95,7 +109,7 @@ def backward(ctx, grad_output):

         g = getattr(backend, 'SpatialConvolutionMM_updateGradInput')
         g(backend.library_state, inp, normalized_grad_output, grad_inp, wplus, columns, ones,
-          kH, kW, ctx.stride[0], ctx.stride[1], ctx.padding[0], ctx.padding[1])
+          kW, kH, ctx.stride[1], ctx.stride[0], ctx.padding[1], ctx.padding[0])

         grad_inp = grad_inp * inp

diff --git a/utils.py b/utils.py
index 84cfaee..efbe7b6 100644
--- a/utils.py
+++ b/utils.py
@@ -31,5 +31,5 @@ def upsample(inp, size):
     backend = type2backend[inp.type()]
     f = getattr(backend, 'SpatialUpSamplingBilinear_updateOutput')
     upsample_inp = inp.new()
-    f(backend.library_state, inp, upsample_inp, size[0], size[1])
-    return upsample_inp
\ No newline at end of file
+    f(backend.library_state, inp, upsample_inp, size[0], size[1], False)
+    return upsample_inp
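
Not part of the patch: for reference, a minimal torch-free sketch of the shape arithmetic the reworked _output_size performs, i.e. broadcasting a single pad/dilation/stride value to every spatial dimension and then applying the usual convolution output formula (in + 2*pad - (dilation*(k-1) + 1)) // stride + 1 per dimension. The helper name conv_output_size and the plain-tuple shapes are illustrative assumptions, not code from the repository.

def conv_output_size(inp_shape, weight_shape, pad, dilation, stride):
    # inp_shape = (N, C_in, *spatial), weight_shape = (C_out, C_in, *kernel)
    spatial_dims = len(inp_shape) - 2
    # a single value is broadcast to every spatial dimension, as in the patched _output_size
    if len(pad) == 1:
        pad = [pad[0]] * spatial_dims
    if len(dilation) == 1:
        dilation = [dilation[0]] * spatial_dims
    if len(stride) == 1:
        stride = [stride[0]] * spatial_dims

    output_size = (inp_shape[0], weight_shape[0])  # (N, C_out)
    for d in range(spatial_dims):
        in_size = inp_shape[d + 2]
        kernel = dilation[d] * (weight_shape[d + 2] - 1) + 1  # effective (dilated) kernel size
        output_size += ((in_size + 2 * pad[d] - kernel) // stride[d] + 1,)
    if not all(s > 0 for s in output_size):
        raise ValueError("convolution input is too small (output would be {})".format(
            'x'.join(map(str, output_size))))
    return output_size

# a 3x3 conv with padding 1 and stride 1 keeps a 1x3x32x32 input at 32x32 spatially
assert conv_output_size((1, 3, 32, 32), (16, 3, 3, 3), [1], [1], [1]) == (1, 16, 32, 32)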