Skip to content

1110812

crodis-strife edited this page Aug 13, 2022 · 13 revisions

References (the code below is adapted from these files):

https://github.com/richzhang/colorization-pytorch/blob/master/util/util.py

https://github.com/richzhang/colorization-pytorch/blob/master/options/base_options.py

import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms

def get_colorization_data(data_raw, ab_thresh=5., p=.125, num_points=100000,
                          samp='normal', sample_Ps=(1, 2, 3, 4, 5, 6, 7, 8, 9)):
    """Build a sparse "hint" image by copying random square patches of ``data_raw``.

    The hint starts as all ones (``torch.ones_like(data_raw)``). For every image
    in the batch, square ``P x P`` patches are then copied from ``data_raw``
    into the hint at random locations. Each ``P`` is drawn uniformly from the
    patch sizes in ``sample_Ps`` that fit inside the image.

    Args:
        data_raw: 4-D tensor, unpacked here as ``(N, C, H, W)``.
        ab_thresh: Unused in this implementation; kept so the signature stays
            compatible with existing callers.
        p: Only used when ``num_points`` is None. Before each patch, sampling
            stops with probability ``p``, so the patch count is geometric.
        num_points: Exact number of patches to copy per image, or None to use
            the geometric stopping rule described for ``p``.
        samp: ``'normal'`` draws the patch's top-left corner from a normal
            distribution with mean ``(H-P+1)/2`` and std ``(H-P+1)/4``, clipped
            to ``[0, H-P]`` (likewise for width). Any other value draws it
            uniformly from ``[0, H-P]``.
        sample_Ps: Candidate patch sizes. Sizes larger than ``min(H, W)`` are
            skipped.

    Returns:
        A tensor shaped like ``data_raw`` (same dtype/device). It holds 1.0
        everywhere except the copied patches, which carry ``data_raw``'s values.
        If no patch size fits the image, the all-ones tensor is returned.
    """
    N, C, H, W = data_raw.shape

    hint = torch.ones_like(data_raw)

    # A patch larger than the image would make the sampling range negative
    # (np.random.normal rejects a negative scale), so only keep sizes that fit.
    patch_sizes = [P for P in sample_Ps if P <= min(H, W)]
    if not patch_sizes:
        return hint

    for nn in range(N):
        pp = 0
        while True:
            if num_points is None:
                # Geometric count: continue with probability 1 - p.
                if np.random.rand() >= 1 - p:
                    break
            elif pp >= num_points:
                break

            P = int(np.random.choice(patch_sizes))  # patch size

            # Sample the top-left corner so the whole P x P patch stays inside.
            if samp == 'normal':
                h = int(np.clip(np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), 0, H - P))
                w = int(np.clip(np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), 0, W - P))
            else:  # uniform distribution
                h = int(np.random.randint(H - P + 1))
                w = int(np.random.randint(W - P + 1))

            # Copy a patch of the sampled size P (previously hard-coded to 4).
            hint[nn, :, h:h + P, w:w + P] = data_raw[nn, :, h:h + P, w:w + P]

            pp += 1

    return hint


def main(in_path='test.jpg', out_path='out.jpg'):
    """Generate a colorization hint image for ``in_path`` and save it to ``out_path``.

    The input is loaded as RGB and turned into a ``(1, C, H, W)`` float tensor
    by ``transforms.ToTensor``. It is then passed through
    ``get_colorization_data`` with that function's default settings. The first
    (only) hint in the batch is written back out as an RGB image.
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
    ])

    # Use a context manager so the file handle is released, and convert to
    # RGB so palette/grayscale/RGBA inputs all yield a 3-channel tensor.
    with Image.open(in_path) as img:
        batch = to_tensor(img.convert('RGB')).unsqueeze(0)  # add batch dim

    hint = get_colorization_data(batch)

    out_img = transforms.ToPILImage()(hint[0]).convert('RGB')
    out_img.save(out_path)


if __name__ == "__main__":
    main()
Clone this wiki locally