Added a reflection padding option from leongatys/NeuralImageSynthesis #377

Open
wants to merge 13 commits into master
19 changes: 18 additions & 1 deletion neural_style.lua
@@ -48,6 +48,7 @@ cmd:option('-seed', -1)
 cmd:option('-content_layers', 'relu4_2', 'layers for content')
 cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style')

+cmd:option('-padding', 'default', 'default|reflect|replicate')

 local function main(params)
   local dtype, multigpu = setup_gpu(params)
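
For context, the new flag would be passed like any other neural_style.lua option; a minimal usage sketch (the image paths here are placeholders, not from the PR):

    th neural_style.lua -content_image content.jpg -style_image style.jpg -padding reflect
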
@@ -117,6 +118,22 @@ local function main(params)
       local layer = cnn:get(i)
       local name = layer.name
       local layer_type = torch.type(layer)
+      -- reflection padding option from leongatys/NeuralImageSynthesis
+      local is_convolution = (layer_type == 'cudnn.SpatialConvolution' or layer_type == 'nn.SpatialConvolution')
+      if is_convolution and params.padding ~= 'default' then
+        local padW, padH = layer.padW, layer.padH
+        if params.padding == 'reflect' then
+          local pad_layer = nn.SpatialReflectionPadding(padW, padW, padH, padH):type(dtype)
+          net:add(pad_layer)
+        elseif params.padding == 'replicate' then
+          local pad_layer = nn.SpatialReplicationPadding(padW, padW, padH, padH):type(dtype)
+          net:add(pad_layer)
+        else
+          error('Unknown padding type: ' .. params.padding)
+        end
+        layer.padW = 0
+        layer.padH = 0
+      end
       local is_pooling = (layer_type == 'cudnn.SpatialMaxPooling' or layer_type == 'nn.SpatialMaxPooling')
       if is_pooling and params.pooling == 'avg' then
         assert(layer.padW == 0 and layer.padH == 0)
@@ -128,7 +145,7 @@ local function main(params)
         net:add(avg_pool_layer)
       else
         net:add(layer)
-      end
+      end
       if name == content_layers[next_content_idx] then
         print("Setting up content layer", i, ":", layer.name)
         local norm = params.normalize_gradients
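
To illustrate the technique outside the diff: a convolution with built-in zero padding can be replaced by an explicit reflection-padding layer followed by the same convolution with padW = padH = 0, which preserves the output size while changing how the border is filled (the usual motivation is avoiding edge artifacts from zero padding). A minimal standalone sketch, assuming the torch and nn packages; the plane counts and tensor sizes are illustrative, not from the PR:

    local nn = require 'nn'

    -- Baseline: 3x3 convolution with built-in zero padding of 1 (keeps spatial size).
    local zero_pad_conv = nn.SpatialConvolution(3, 8, 3, 3, 1, 1, 1, 1)

    -- PR-style alternative: reflect-pad explicitly, then convolve with no padding.
    local net = nn.Sequential()
    net:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
    net:add(nn.SpatialConvolution(3, 8, 3, 3, 1, 1, 0, 0))

    local x = torch.randn(1, 3, 32, 32)
    print(zero_pad_conv:forward(x):size())  -- 1x8x32x32
    print(net:forward(x):size())            -- 1x8x32x32: same shape, different border values

This is exactly the rewrite the patch applies to each convolution in the loaded VGG network: it inserts the pad layer, then zeroes the convolution's own padW/padH so the border is not padded twice.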