Added support for projective warp via full homogeneous coordinates. #11

Open · wants to merge 2 commits into `master`
README.md: 21 changes (14 additions, 7 deletions)
@@ -16,9 +16,9 @@ This is a **TensorFlow** implementation of [Spatial Transformer Networks](https:

The STN is composed of 3 elements.

- **localization network**: takes the feature map as input and outputs the parameters of the affine transformation that should be applied to that feature map.
- **localization network**: takes the feature map as input and outputs the parameters of the transformation that should be applied to that feature map.

- **grid generator:** generates a grid of (x,y) coordinates using the parameters of the affine transformation that correspond to a set of points where the input feature map should be sampled to produce the transformed output feature map.
- **grid generator:** generates a grid of (x,y) coordinates using the parameters of the transformation that correspond to a set of points where the input feature map should be sampled to produce the transformed output feature map.

- **bilinear sampler:** takes as input the input feature map and the grid generated by the grid generator and produces the output feature map using bilinear interpolation.

@@ -36,13 +36,20 @@ It can be constrained to an *attention* mechanism by writing it in the form

where the parameters `s`, `t_x` and `t_y` can be regressed to allow cropping, translation, and isotropic scaling.

A more general projective transformation can also be specified through the transformation matrix `B` (a code sketch of how `B` acts follows the figure):

<p align="center">
<img src="./img/projective.png" width="175px">
</p>
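For concreteness, here is a small editorial sketch (not part of this PR; the parameter values are made up for illustration) of how the attention-constrained matrix and a projective `B` act on a point in homogeneous coordinates `(x, y, 1)`:

```python
import numpy as np

# attention-constrained affine: isotropic scale s, translation (t_x, t_y)
s, t_x, t_y = 0.5, 0.1, -0.2
A = np.array([[s, 0., t_x],
              [0., s, t_y]])
x, y = 0.8, -0.4
print(A @ np.array([x, y, 1.]))     # -> [0.5, -0.4]; the affine case needs no divide

# projective B: a full 3x3 matrix; the last row introduces the perspective effect
B = np.array([[1., 0.2, 0.],
              [0., 1.,  0.],
              [0.3, 0., 1.]])
xh, yh, w = B @ np.array([x, y, 1.])
print(xh / w, yh / w)               # perspective divide by the third component
```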

## Explore

Run the [Sanity Check](https://github.com/kevinzakka/spatial-transformer-network/blob/master/Sanity%20Check.ipynb) to get a feel of how the spatial transformer can be plugged into any existing code. For example, here's the result of a 45 degree rotation:
Run the [Sanity Check](https://github.com/kevinzakka/spatial-transformer-network/blob/master/Sanity%20Check.ipynb) to get a feel for how the spatial transformer can be plugged into any existing code. For example, here's the result of a 45-degree rotation and a separate projective warp:

<p align="center">
<img src="./img/b4.png" alt="Drawing" width="40%">
<img src="./img/after.png" alt="Drawing" width="40%">
<img src="./img/b4.png" alt="Drawing" width="30%">
<img src="./img/after.png" alt="Drawing" width="30%">
<img src="./img/after_2.png" alt="Drawing" width="30%">
</p>

## API
@@ -56,12 +63,12 @@ out = spatial_transformer_network(input_feature_map, theta, out_dims)
**Parameters**

- `input_feature_map`: the output of the layer preceding the localization network. If the STN layer is the first layer of the network, then this corresponds to the input images. Shape should be (B, H, W, C).
- `theta`: this is the output of the localization network. Shape should be (B, 6)
- `theta`: the output of the localization network. Shape should be (B, X), where X is the number of transformation parameters (usually 6 or 8); see the sketch after this list.
- `out_dims`: desired (H, W) of the output feature map. Useful for upsampling or downsampling. If not specified, then output dimensions will be equal to `input_feature_map` dimensions.
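For illustration, a hedged usage sketch (not from the repo's docs; the placeholder shapes are made up) with an 8-parameter projective `theta`, namely the flattened 3x3 identity without its final entry, which the layer pads back internally:

```python
import tensorflow as tf
from transformer import spatial_transformer_network

# made-up shapes for illustration
B, H, W, C = 16, 32, 32, 1
input_feature_map = tf.placeholder(tf.float32, [B, H, W, C])

# 8-parameter projective theta: flattened 3x3 identity minus the final 1,
# which pad_theta restores internally
theta = tf.tile(tf.constant([[1., 0., 0., 0., 1., 0., 0., 0.]]), [B, 1])

# output keeps the input (H, W) since out_dims is not given
out = spatial_transformer_network(input_feature_map, theta)
```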

**Note**

You must define a localization network right before using this layer. The localization network is usually a ConvNet or a FC-net that has 6 output nodes (the 6 parameters of the affine transformation).
You must define a localization network right before using this layer. The localization network is usually a ConvNet or a FC-net whose number of output nodes matches the number of transformation parameters (e.g. 6 for an affine transformation).

You need to initialize the localization network to the identity transform before starting the training process. Here's a small sample code for illustration purposes.
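(The repo's own sample is collapsed in this diff view; the following is an editorial sketch of the same idea, with illustrative names and shapes, mirroring the notebook cell later in this diff. The trick is zero weights plus an identity bias.)

```python
import numpy as np
import tensorflow as tf

# identity affine transform, flattened to 6 parameters
identity_theta = np.array([[1., 0., 0.],
                           [0., 1., 0.]], dtype='float32').flatten()

# hypothetical flattened input of size n_inputs from the preceding layer
n_inputs = 40 * 40 * 1
flat = tf.placeholder(tf.float32, [None, n_inputs])

# zero weights + identity bias => the net outputs the identity transform at step 0
W_loc = tf.Variable(tf.zeros([n_inputs, identity_theta.size]), name='W_loc')
b_loc = tf.Variable(initial_value=identity_theta, name='b_loc')
theta = tf.matmul(flat, W_loc) + b_loc  # shape (B, 6)
```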

Sanity Check.ipynb: 13 changes (8 additions, 5 deletions)
@@ -120,7 +120,10 @@
"theta = np.array([[np.cos(angleRad), -np.sin(angleRad), 0], [np.sin(angleRad), np.cos(angleRad), 0]])\n",
"\n",
"# # identity transform\n",
"# theta = np.array([[1., 0, 0], [0, 1., 0]])"
"# theta = np.array([[1., 0, 0], [0, 1., 0]])\n",
"\n",
"# # perspective transform\n",
"# theta = np.array([[1, 0, 0], [0, 1, 0], [0.6, 0, 1]])"
]
},
{
@@ -134,13 +137,13 @@
"# create localisation network and convolutional layer\n",
"with tf.variable_scope('spatial_transformer_0'):\n",
"\n",
" # create a fully-connected layer with 6 output nodes\n",
" n_fc = 6\n",
" W_fc1 = tf.Variable(tf.zeros([H*W*C, n_fc]), name='W_fc1')\n",
"\n",
" # affine transformation\n",
" theta = theta.astype('float32')\n",
" theta = theta.flatten()\n",
" \n",
" # create a fully-connected layer with the correct number of output nodes\n",
" n_fc = theta.shape[0]\n",
" W_fc1 = tf.Variable(tf.zeros([H*W*C, n_fc]), name='W_fc1')\n",
"\n",
" b_fc1 = tf.Variable(initial_value=theta, name='b_fc1')\n",
" h_fc1 = tf.matmul(tf.zeros([B, H*W*C]), W_fc1) + b_fc1\n",
Binary file added img/after_2.png
Binary file added img/projective.png
transformer.py: 59 changes (45 additions, 14 deletions)
@@ -24,9 +24,10 @@ def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
      transformer layer is at the beginning of the architecture. Should be
      a tensor of shape (B, H, W, C).

    - theta: affine transform tensor of shape (B, 6). Permits cropping,
      translation and isotropic scaling. Initialize to identity matrix.
      It is the output of the localization network.
    - theta: transform tensor of shape (B, X), where X <= 9. Permits
      cropping, translation, isotropic scaling and projective
      transformations. Initialize to the identity transform. It is the
      output of the localization network.

    Returns
    -------
@@ -44,25 +45,51 @@ def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
    W = tf.shape(input_fmap)[2]
    C = tf.shape(input_fmap)[3]

    # reshape theta to (B, 2, 3)
    theta = tf.reshape(theta, [B, 2, 3])
    # pad theta to a flattened 3x3 transform matrix, shape (B, 9)
    theta = pad_theta(theta)
    # reshape theta to (B, 3, 3)
    theta = tf.reshape(theta, [B, 3, 3])

    # generate grids of same size or upsample/downsample if specified
    if out_dims:
        out_H = out_dims[0]
        out_W = out_dims[1]
        batch_grids = affine_grid_generator(out_H, out_W, theta)
        x_s, y_s = affine_grid_generator(out_H, out_W, theta)
    else:
        batch_grids = affine_grid_generator(H, W, theta)
        x_s, y_s = affine_grid_generator(H, W, theta)

    x_s = batch_grids[:, 0, :, :]
    y_s = batch_grids[:, 1, :, :]

    # sample input with grid to get output
    out_fmap = bilinear_sampler(input_fmap, x_s, y_s)

    return out_fmap

def pad_theta(theta):
    """
    Utility function to pad the input theta to a full 3x3 transformation
    matrix, filling the missing entries from the flattened 3x3 identity.

    Input
    -----
    - theta: tensor of shape (B, X) where X <= 9

    Returns
    -------
    - theta_padded: input theta padded (if needed) to a flattened 3x3
      transform matrix of shape (B, 9)
    """
    B = tf.shape(theta)[0]
    theta_params = tf.shape(theta)[1]

    assertion = tf.Assert(theta_params <= 9, [theta_params])

    with tf.control_dependencies([assertion]):
        identity_flat = tf.reshape(tf.eye(3), [3*3])
        identity_remaining = identity_flat[theta_params:]
        identity_batch = tf.reshape(tf.tile(identity_remaining, [B]), [B, 9 - theta_params])
        theta_padded = tf.concat([theta, identity_batch], axis=1)

    return theta_padded
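An editorial illustration of the padding behavior (not part of the diff): a 6-parameter affine theta is completed with the tail of the flattened 3x3 identity, i.e. [0, 0, 1]. A NumPy equivalent sketch:

```python
import numpy as np

theta = np.array([[1., 0., 0., 0., 1., 0.]])       # (B=1, 6) affine params
identity_flat = np.eye(3).flatten()                # [1, 0, 0, 0, 1, 0, 0, 0, 1]
pad = np.tile(identity_flat[theta.shape[1]:], (theta.shape[0], 1))  # [[0, 0, 1]]
theta_padded = np.concatenate([theta, pad], axis=1).reshape(-1, 3, 3)
print(theta_padded[0])
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]
```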

def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for coordinate
@@ -148,12 +175,16 @@ def affine_grid_generator(height, width, theta):

    # transform the sampling grid - batch multiply
    batch_grids = tf.matmul(theta, sampling_grid)
    # batch grid has shape (num_batch, 2, H*W)
    # batch grid has shape (num_batch, 3, H*W)

    # reshape to (num_batch, 3, H, W)
    batch_grids = tf.reshape(batch_grids, [num_batch, 3, height, width])

    # reshape to (num_batch, H, W, 2)
    batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])

    # homogeneous -> 2D: divide by the homogeneous coordinate w
    x_s = batch_grids[:, 0, :, :] / batch_grids[:, 2, :, :]
    y_s = batch_grids[:, 1, :, :] / batch_grids[:, 2, :, :]

    return batch_grids
    return x_s, y_s
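As an editorial check of the new perspective divide, the notebook's example matrix [[1, 0, 0], [0, 1, 0], [0.6, 0, 1]] sends a normalized grid point (x, y) to (x / (0.6x + 1), y / (0.6x + 1)); a NumPy sketch:

```python
import numpy as np

theta = np.array([[1.,  0., 0.],
                  [0.,  1., 0.],
                  [0.6, 0., 1.]])

# normalized grid points, stacked as homogeneous column vectors
grid = np.array([[-1., 0., 1.],    # x
                 [-1., 0., 1.],    # y
                 [ 1., 1., 1.]])   # homogeneous coordinate

out = theta @ grid
x_s = out[0] / out[2]   # x / (0.6*x + 1)
y_s = out[1] / out[2]
print(x_s)  # [-2.5   0.    0.625]
print(y_s)  # [-2.5   0.    0.625]
```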

def bilinear_sampler(img, x, y):
"""
Expand Down