Skip to content

Commit

Permalink
Added color transfer to script_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Nov 16, 2016
1 parent c46978b commit 942c0ea
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
Binary file modified script_helper/Neural Style Transfer.exe
Binary file not shown.
Binary file modified script_helper/Neural Style Transfer.pdb
Binary file not shown.
46 changes: 37 additions & 9 deletions script_helper/Script/color_transfer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
import argparse
import os
import numpy as np
from scipy.interpolate import interp1d
from scipy.misc import imread, imresize, imsave, fromimage, toimage


# Util function to match histograms
def match_histograms(A, B, rng=(0.0, 255.0), bins=64):
(Ha, Xa), (Hb, Xb) = [np.histogram(i, bins=bins, range=rng, density=True) for i in [A, B]]
X = np.linspace(rng[0], rng[1], bins, endpoint=True)
Hpa, Hpb = [np.cumsum(i) * (rng[1] - rng[0]) ** 2 / float(bins) for i in [Ha, Hb]]
inv_Ha = interp1d(X, Hpa, bounds_error=False)
map_Hb = interp1d(Hpb, X, bounds_error=False)
return map_Hb(inv_Ha(A).clip(0.0, 255.0))


# util function to preserve image color
def original_color_transform(content, generated, mask=None):
generated = fromimage(toimage(generated, mode='RGB'), mode='YCbCr') # Convert to YCbCr color space
def original_color_transform(content, generated, mask=None, hist_match=0, mode='YCbCr'):
generated = fromimage(toimage(generated, mode='RGB'), mode=mode) # Convert to YCbCr color space

if mask is None:
generated[:, :, 1:] = content[:, :, 1:] # Generated CbCr = Content CbCr
if hist_match == 0:
for channel in range(3):
generated[:, :, channel] = match_histograms(generated[:, :, channel], content[:, :, channel])
else:
generated[:, :, 1:] = content[:, :, 1:]
else:
width, height, channels = generated.shape

for i in range(width):
for j in range(height):
if mask[i, j] == 1:
generated[i, j, 1:] = content[i, j, 1:]
if hist_match == 0:
for channel in range(3):
generated[i, j, channel] = match_histograms(generated[i, j, channel], content[i, j, channel])
else:
generated[i, j, 1:] = content[i, j, 1:]

generated = fromimage(toimage(generated, mode='YCbCr'), mode='RGB') # Convert to RGB color space
generated = fromimage(toimage(generated, mode=mode), mode='RGB') # Convert to RGB color space
return generated


Expand All @@ -41,16 +61,24 @@ def load_mask(mask_path, shape):

parser.add_argument('content_image', type=str, help='Path to content image')
parser.add_argument('generated_image', type=str, help='Path to generated image')
parser.add_argument("--mask", default=None, type=str, help='Path to mask image')
parser.add_argument('--mask', default=None, type=str, help='Path to mask image')
parser.add_argument('--hist_match', type=int, default=0, help='Perform histogram matching for color matching')

args = parser.parse_args()

image_path = os.path.splitext(args.generated_image)[0] + "_original_color.png"
if args.hist_match == 0:
image_suffix = "_histogram_color.png"
mode = "RGB"
else:
image_suffix = "_original_color.png"
mode = "YCbCr"

image_path = os.path.splitext(args.generated_image)[0] + image_suffix

generated_image = imread(args.generated_image, mode="RGB")
img_width, img_height, _ = generated_image.shape

content_image = imread(args.content_image, mode='YCbCr')
content_image = imread(args.content_image, mode=mode)
content_image = imresize(content_image, (img_width, img_height), interp='bicubic')

mask_transfer = args.mask is not None
Expand All @@ -59,7 +87,7 @@ def load_mask(mask_path, shape):
else:
mask_img = None

img = original_color_transform(content_image, generated_image, mask_img)
img = original_color_transform(content_image, generated_image, mask_img, args.hist_match, mode=mode)
imsave(image_path, img)

print("Image saved at path : %s" % image_path)
Expand Down

0 comments on commit 942c0ea

Please sign in to comment.