Skip to content

Commit

Permalink
Dilation is available
Browse files Browse the repository at this point in the history
  • Loading branch information
David Nilsson committed Jun 13, 2018
1 parent af98368 commit 518d62b
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 11 deletions.
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,29 @@ Check config.py. Download all data from the cityscapes dataset and change the pa

Run misc/compile.sh to compile the bilinear warping operator. Change the include directory on line 9 if you get errors related to libcudart.

Download all pretrained models from [here](https://drive.google.com/open?id=1eGy7JcX1ptzxwQ6thEd2R_ix4VehLRQL) and unpack them under ./models. For instance, the file ./models/flownet1.index should exist.
Download all pretrained models from [here](https://drive.google.com/open?id=1eGy7JcX1ptzxwQ6thEd2R_ix4VehLRQL) and unpack them under ./checkpoints/. For instance, the file ./checkpoints/flownet1.index should exist.

### Usage

Evaluate the GRFP(LRR-4x, FlowNet2) setup on the validation set by running
Evaluate the GRFP(LRR-4x, FlowNet2) setup on the validation set by running:
```
python evaluate.py --static lrr --flow flownet2
```

Evaluation using PSP and Dilation10 as well as code for training will be added soon.
Evaluate GRFP(Dilation10, FlowNet2) for various numbers of frames, as in Tables 3 and 4 in the paper:
```
python evaluate.py --static dilation --flow flownet2 --frames 1
python evaluate.py --static dilation --flow flownet2 --frames 5
```

The values in table 9 can be reproduced by running the following. It takes about 4 hours on a titan X GPU.
The values in table 9 can be reproduced by running the following:
```
python evaluate.py --static lrr --flow flownet2
python evaluate.py --static lrr --flow flownet1
python evaluate.py --static lrr --flow farneback
```

Evaluation using PSP and code for training will be added soon.

### Citation
If you use the code in your own research, please cite
Expand Down
31 changes: 24 additions & 7 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import config as cfg
from models.stgru import STGRU
from models.lrr import LRR
from models.dilation import dilation10network
from models.flownet2 import Flownet2
from models.flownet1 import Flownet1
from tensorflow.python.framework import ops
Expand All @@ -26,7 +27,7 @@ def evaluate(args):
cs_id2trainid, cs_id2name = pickle.load(f)
f.close()

assert args.static == 'lrr', "Only LRR is supported for now."
assert args.static in ['dilation', 'lrr'], "Only dilation and LRR are supported for now."

if args.flow == 'flownet2':
with tf.variable_scope('flow'):
Expand All @@ -47,17 +48,25 @@ def evaluate(args):
input_segmentation, prev_h, new_h, \
prediction = RNN.get_one_step_predictor()

static_input = tf.placeholder(tf.float32)
static_network = LRR()
static_output = static_network(static_input)
if args.static == 'lrr':
static_input = tf.placeholder(tf.float32)
static_network = LRR()
static_output = static_network(static_input)
elif args.static == 'dilation':
static_input = tf.placeholder(tf.float32)
static_network = dilation10network()
static_output = static_network.get_output_tensor(static_input, im_size)

saver = tf.train.Saver([k for k in tf.global_variables() if not k.name.startswith('flow/')])
if args.flow in ['flownet1', 'flownet2']:
saver_fn = tf.train.Saver([k for k in tf.global_variables() if k.name.startswith('flow/')])

with tf.Session() as sess:
saver.restore(sess, './checkpoints/lrr_grfp')

if args.static == 'lrr':
saver.restore(sess, './checkpoints/lrr_grfp')
elif args.static == 'dilation':
saver.restore(sess, './checkpoints/dilation_grfp')

if args.flow == 'flownet1':
saver_fn.restore(sess, './checkpoints/flownet1')
elif args.flow == 'flownet2':
Expand Down Expand Up @@ -93,7 +102,15 @@ def evaluate(args):
flow = flow[np.newaxis,...]

# Static segmentation
x = sess.run(static_output, feed_dict={static_input: im})
if args.static == 'dilation':
# augment a 186x186 border around the image and subtract the mean
im_aug = cv2.copyMakeBorder(im[0], 186, 186, 186, 186, cv2.BORDER_REFLECT_101)
im_aug = im_aug - image_mean
im_aug = im_aug[np.newaxis,...]

x = sess.run(static_output, feed_dict={static_input: im_aug})
elif args.static == 'lrr':
x = sess.run(static_output, feed_dict={static_input: im})

if first_frame:
# the hidden state is simple the static segmentation for the first frame
Expand Down
139 changes: 139 additions & 0 deletions models/dilation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import tensorflow as tf

class dilation10network:
    """TensorFlow graph builder for the Dilation10 segmentation network.

    The constructor creates zero-initialised ``tf.Variable`` objects for every
    layer (``conv1_1`` .. ``conv5_3``, ``fc6``, ``fc7``, ``final`` and the
    ``ctx_*`` layers); real values are expected to be restored from a
    checkpoint afterwards. ``get_output_tensor`` wires those variables into a
    forward pass producing 19-channel output, and ``get_optimizer`` builds a
    training op for the ``ctx_*`` / ``fc6`` / ``fc7`` / ``final`` variables.
    """

    # (dictionary key, filter shape [height, width, in_channels, out_channels])
    # in creation order. The order matters: TensorFlow makes repeated variable
    # names unique by appending a counter in creation order.
    _LAYERS = (
        ('conv1_1', (3, 3, 3, 64)),
        ('conv1_2', (3, 3, 64, 64)),

        ('conv2_1', (3, 3, 64, 128)),
        ('conv2_2', (3, 3, 128, 128)),

        ('conv3_1', (3, 3, 128, 256)),
        ('conv3_2', (3, 3, 256, 256)),
        ('conv3_3', (3, 3, 256, 256)),

        ('conv4_1', (3, 3, 256, 512)),
        ('conv4_2', (3, 3, 512, 512)),
        ('conv4_3', (3, 3, 512, 512)),

        ('conv5_1', (3, 3, 512, 512)),
        ('conv5_2', (3, 3, 512, 512)),
        ('conv5_3', (3, 3, 512, 512)),

        ('fc6', (7, 7, 512, 4096)),
        ('fc7', (1, 1, 4096, 4096)),
        ('final', (1, 1, 4096, 19)),

        ('ctx_conv1_1', (3, 3, 19, 19)),
        ('ctx_conv1_2', (3, 3, 19, 19)),
        ('ctx_conv2_1', (3, 3, 19, 19)),
        ('ctx_conv3_1', (3, 3, 19, 19)),
        ('ctx_conv4_1', (3, 3, 19, 19)),
        ('ctx_conv5_1', (3, 3, 19, 19)),
        ('ctx_conv6_1', (3, 3, 19, 19)),
        ('ctx_conv7_1', (3, 3, 19, 19)),
        ('ctx_fc1', (3, 3, 19, 19)),
        ('ctx_final', (1, 1, 19, 19)),
        ('ctx_upsample', (16, 16, 19, 19)),
    )

    # Layers that have a filter but no bias variable.
    _NO_BIAS = ('ctx_upsample',)

    # Non-``ctx_`` layers that get_optimizer also trains.
    _FINETUNE_EXTRA = ('fc6', 'fc7', 'final')

    def __init__(self, dropout_keeprate = 1.0):
        """Create all weight and bias variables, initialised to zero.

        Args:
            dropout_keeprate: value passed to ``tf.nn.dropout`` after the
                ``fc6`` and ``fc7`` layers in ``get_output_tensor``.
        """
        self.dropout_keeprate = dropout_keeprate
        # Not used inside this class; presumably the per-channel image mean
        # expected by the network -- TODO confirm against the caller.
        self.mean = [72.39,82.91,73.16]

        # All filters are created first, then all biases, so that the
        # automatically uniquified variable names stay stable.
        self.weights = {}
        for key, shape in self._LAYERS:
            self.weights[key] = tf.Variable(
                tf.zeros(list(shape), dtype=tf.float32),
                name=self._variable_name(key))

        self.biases = {}
        for key, shape in self._LAYERS:
            if key in self._NO_BIAS:
                continue
            # One bias per output channel (last filter dimension).
            self.biases[key] = tf.Variable(
                tf.zeros([shape[-1]], dtype=tf.float32),
                name=self._variable_name(key) + '_b')

    @staticmethod
    def _variable_name(key):
        """Return the graph name requested for the filter variable of `key`.

        NOTE(review): every ``ctx_*`` variable is deliberately requested under
        the single name 'ctx_conv1_1' (TensorFlow then appends a counter).
        This looks like a copy-paste slip, but the names are kept as they are
        because changing them would change the variable names used when
        restoring a checkpoint -- confirm against the released checkpoints
        before renaming.
        """
        if key.startswith('ctx_'):
            return 'ctx_conv1_1'
        return key

    def _conv(self, x, key, padding):
        """Stride-1 convolution with the filter of `key`, plus its bias."""
        return tf.nn.conv2d(x, self.weights[key], strides=[1,1,1,1], padding=padding) + self.biases[key]

    def _atrous(self, x, key, rate, padding):
        """Dilated convolution with the filter of `key`, plus its bias."""
        return tf.nn.atrous_conv2d(x, self.weights[key], padding=padding, rate=rate) + self.biases[key]

    @staticmethod
    def _pool(x):
        """2x2 max pooling with stride 2."""
        return tf.nn.max_pool(x, [1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")

    def get_output_tensor(self, x, out_size):
        """Build the forward pass and return the upsampled output tensor.

        Args:
            x: input tensor fed to ``conv1_1``. All layers up to ``final`` use
                "VALID" padding, so the spatial size shrinks through them.
            out_size: indexable; ``out_size[0]`` and ``out_size[1]`` become
                the two spatial dimensions of the output.

        Returns:
            The result of the final transposed convolution, requested with
            shape ``[1, out_size[0], out_size[1], 19]``.
        """
        output_shape = [1, out_size[0], out_size[1], 19]

        net = x

        # conv1 - conv3: groups of ReLU convolutions, each followed by pooling.
        pooled_groups = (
            ('conv1_1', 'conv1_2'),
            ('conv2_1', 'conv2_2'),
            ('conv3_1', 'conv3_2', 'conv3_3'),
        )
        for group in pooled_groups:
            for key in group:
                net = tf.nn.relu(self._conv(net, key, "VALID"))
            net = self._pool(net)

        for key in ('conv4_1', 'conv4_2', 'conv4_3'):
            net = tf.nn.relu(self._conv(net, key, "VALID"))
        # not pooling, instead dilations in the following ops

        for key in ('conv5_1', 'conv5_2', 'conv5_3'):
            net = tf.nn.relu(self._atrous(net, key, 2, "VALID"))

        net = tf.nn.relu(self._atrous(net, 'fc6', 4, "VALID"))
        net = tf.nn.dropout(net, self.dropout_keeprate)
        net = tf.nn.relu(self._atrous(net, 'fc7', 4, "VALID"))
        net = tf.nn.dropout(net, self.dropout_keeprate)
        # No ReLU on the 19-channel 'final' output.
        net = self._atrous(net, 'final', 4, "VALID")

        # ctx_* layers: "SAME" padding, dilation rate doubling from 2 to 64.
        net = tf.nn.relu(self._conv(net, 'ctx_conv1_1', "SAME"))
        net = tf.nn.relu(self._conv(net, 'ctx_conv1_2', "SAME"))
        dilated_ctx = (
            ('ctx_conv2_1', 2),
            ('ctx_conv3_1', 4),
            ('ctx_conv4_1', 8),
            ('ctx_conv5_1', 16),
            ('ctx_conv6_1', 32),
            ('ctx_conv7_1', 64),
        )
        for key, rate in dilated_ctx:
            net = tf.nn.relu(self._atrous(net, key, rate, "SAME"))

        net = tf.nn.relu(self._conv(net, 'ctx_fc1', "SAME"))
        net = self._conv(net, 'ctx_final', "SAME")

        # Transposed convolution with stride 8; it has a filter but no bias.
        ctx_upsample = tf.nn.conv2d_transpose(net, self.weights['ctx_upsample'], output_shape=output_shape, strides=[1,8,8,1])

        return ctx_upsample

    def _finetune_variables(self):
        """Return the variables trained by ``get_optimizer`` as a list.

        Selects the weights, then the biases, whose key starts with 'ctx_' or
        is one of 'fc6', 'fc7', 'final'.
        """
        def wanted(key):
            return key.startswith('ctx_') or key in self._FINETUNE_EXTRA

        # .items() instead of .iteritems() so this runs on Python 2 and 3.
        selected = [var for key, var in self.weights.items() if wanted(key)]
        selected += [var for key, var in self.biases.items() if wanted(key)]
        return selected

    def get_optimizer(self, x, y, learning_rate):
        """Build a momentum training op driven by an externally fed gradient.

        Args:
            x: unused; kept so existing callers keep working.
            y: tensor the fed gradient is multiplied with.
            learning_rate: learning rate for the momentum optimizer.

        Returns:
            Tuple ``(train_op, dLdy)``, where ``dLdy`` is the placeholder that
            must be fed when running ``train_op``.
        """
        # optimize wrt the ctx_* variables (plus fc6, fc7 and final)
        dLdy = tf.placeholder('float')

        # the correct values will backpropagate to ctx_upsample
        loss = tf.reduce_sum(dLdy * y)

        opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
        # The list must go in as the var_list keyword: the second positional
        # argument of minimize() is global_step. The previous code also built
        # the list with list.extend(), which returns None, so the restriction
        # was silently dropped and every trainable variable was optimised.
        opt = opt.minimize(loss, var_list=self._finetune_variables())

        return opt, dLdy

0 comments on commit 518d62b

Please sign in to comment.