diff --git a/custom_model.py b/custom_model.py index 451d599..b52b675 100644 --- a/custom_model.py +++ b/custom_model.py @@ -1,4 +1,5 @@ from chainer import functions as F +from chainer.functions.pooling.average_pooling_2d import average_pooling_2d from chainer import links as L from chainer import Chain from resnet_group_norm import ResNet as ResNetGroupNorm @@ -8,8 +9,8 @@ class CustomModel(Chain): def __init__(self, n_actions): super(CustomModel, self).__init__() with self.init_scope(): - self.resNet=L.ResNet50Layers() - self.l1=L.Linear(2138, 1024) + self.resNet=ResNetGroupNorm(n_layers=18) + self.l1=L.Linear(602, 1024) self.l2=L.Linear(1024, 1024) self.l3=L.Linear(1024, n_actions) @@ -17,8 +18,15 @@ def forward(self, x): image, history = x[0], x[1] image = F.reshape(image, (-1,3,224,224)) history = F.reshape(history.astype('float32'),(-1,90)) - h1 = F.relu(self.resNet(image, layers=['pool5'])['pool5']) - h1 = F.reshape(F.concat((h1, history), axis=1), (-1,2138)) + h1 = self.resNet(image) + + # pooling as done here: https://github.com/chainer/chainer/blob/v6.0.0/chainer/links/model/vision/resnet.py#L655 + n, channel, rows, cols = h1.shape + h1 = average_pooling_2d(h1, (rows, cols), stride=1) + h1 = F.reshape(h1, (n, channel)) + + h1 = F.relu(h1) + h1 = F.reshape(F.concat((h1, history), axis=1), (-1,602)) h2 = F.relu(self.l1(h1)) h3 = F.relu(self.l2(h2)) return F.relu(self.l3(h3)) diff --git a/resnet_group_norm.py b/resnet_group_norm.py index 8fc0c89..9021510 100644 --- a/resnet_group_norm.py +++ b/resnet_group_norm.py @@ -47,11 +47,11 @@ def __init__(self, n_layers, class_labels=None): self.res4 = BasicBlock(block[2], 512) elif n_layers in [18, 20, 21, 34]: self.conv1 = L.Convolution2D(3, 64, 7, 2, 3, initialW=w, nobias=True) - self.bn1 = L.GroupNormalization(16) - self.res2 = BasicBlock(block[0], 64, 1, num_groups=16) - self.res3 = BasicBlock(block[1], 128) - self.res4 = BasicBlock(block[2], 256) - self.res5 = BasicBlock(block[3], 512) + self.bn1 = L.GroupNormalization(16, 64) + self.res2 = BasicBlock(block[0], 64, 64, 1, num_groups=16) + self.res3 = BasicBlock(block[1], 64, 128) + self.res4 = BasicBlock(block[2], 128, 256) + self.res5 = BasicBlock(block[3], 256, 512) elif n_layers in [32, 44, 56, 110]: self.conv1 = L.Convolution2D(3, 16, 7, 2, 3, initialW=w, nobias=True) self.bn1 = L.GroupNormalization(8) @@ -98,12 +98,12 @@ def __call__(self, x): class BasicBlock(chainer.ChainList): - def __init__(self, layer, ch, stride=2, num_groups=32): + def __init__(self, layer, input_ch, output_ch, stride=2, num_groups=32): super(BasicBlock, self).__init__() with self.init_scope(): - self.add_link(BasicA(ch, stride, num_groups)) + self.add_link(BasicA(input_ch, output_ch, stride, num_groups)) for i in range(layer - 1): - self.add_link(BasicB(ch, num_groups)) + self.add_link(BasicB(output_ch, num_groups)) def __call__(self, x): for f in self.children(): @@ -127,18 +127,18 @@ def __call__(self, x): class BasicA(chainer.Chain): - def __init__(self, ch, stride, num_groups): + def __init__(self, input_ch, output_ch, stride, num_groups): super(BasicA, self).__init__() w = chainer.initializers.HeNormal() with self.init_scope(): - self.conv1 = L.Convolution2D(None, ch, 3, stride, 1, initialW=w, nobias=True) - self.bn1 = L.GroupNormalization(num_groups) - self.conv2 = L.Convolution2D(None, ch, 3, 1, 1, initialW=w, nobias=True) - self.bn2 = L.GroupNormalization(num_groups) + self.conv1 = L.Convolution2D(input_ch, output_ch, 3, stride, 1, initialW=w, nobias=True) + self.bn1 = L.GroupNormalization(num_groups, output_ch) + self.conv2 = L.Convolution2D(output_ch, output_ch, 3, 1, 1, initialW=w, nobias=True) + self.bn2 = L.GroupNormalization(num_groups, output_ch) - self.conv3 = L.Convolution2D(None, ch, 3, stride, 1, initialW=w, nobias=True) - self.bn3 = L.GroupNormalization(num_groups) + self.conv3 = L.Convolution2D(input_ch, output_ch, 3, stride, 1, initialW=w, nobias=True) + self.bn3 = L.GroupNormalization(num_groups, output_ch) def __call__(self, x): h1 = F.relu(self.bn1(self.conv1(x))) @@ -155,10 +155,10 @@ def __init__(self, ch, num_groups): w = chainer.initializers.HeNormal() with self.init_scope(): - self.conv1 = L.Convolution2D(None, ch, 3, 1, 1, initialW=w, nobias=True) - self.bn1 = L.GroupNormalization(num_groups) - self.conv2 = L.Convolution2D(None, ch, 3, 1, 1, initialW=w, nobias=True) - self.bn2 = L.GroupNormalization(num_groups) + self.conv1 = L.Convolution2D(ch, ch, 3, 1, 1, initialW=w, nobias=True) + self.bn1 = L.GroupNormalization(num_groups, ch) + self.conv2 = L.Convolution2D(ch, ch, 3, 1, 1, initialW=w, nobias=True) + self.bn2 = L.GroupNormalization(num_groups, ch) def __call__(self, x): h = F.relu(self.bn1(self.conv1(x))) @@ -220,4 +220,4 @@ def __call__(self, x): h = F.relu(self.bn2(self.conv2(h))) h = self.bn3(self.conv3(h)) - return F.relu(h + x) \ No newline at end of file + return F.relu(h + x)