-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.py
127 lines (91 loc) · 3.46 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import random
import numpy as np
import skimage.io as sio
import skimage.color as sc
import skimage.transform as st
from scipy import ndimage
import torch
from torchvision import transforms
def input2tuple(val, ndim, default, last_dim_chan=True):
    """Expand a per-axis parameter into a tuple of length ``ndim``.

    Args:
        val: None (use ``default`` on every axis), a list/tuple/ndarray of
            per-axis values (copied as-is), or any other value, which is
            broadcast to every axis.
        ndim: number of axes when ``val`` is None or a scalar.
        default: fill value for None, and the value forced onto the last
            axis when ``last_dim_chan`` is true.
        last_dim_chan: if true, treat the last axis as channels and overwrite
            its entry with ``default``.

    Returns:
        tuple of per-axis values.
    """
    if isinstance(val, (list, tuple, np.ndarray)):
        values = list(val)
    else:
        # None means "use the default everywhere"; anything else is a scalar
        # that applies to every axis.
        fill = default if val is None else val
        values = [fill] * ndim

    if last_dim_chan:
        # The channel axis always gets the neutral default.
        values[-1] = default
    return tuple(values)
def downsample(X, strides=None, offsets=None, sigmas=None, last_dim_chan=True):
    """Optionally Gaussian-blur an array, then subsample it on a regular grid.

    Each of ``strides``, ``offsets`` and ``sigmas`` may be None, a scalar
    (broadcast to every axis) or a per-axis sequence, and is expanded with
    ``input2tuple``. With ``last_dim_chan=True`` the last axis is treated as
    channels and always gets stride 1, offset 0 and sigma 0.

    Args:
        X: array to downsample.
        strides: step along each axis (None -> 1, i.e. keep every sample).
        offsets: index of the first kept sample along each axis (None -> 0).
        sigmas: Gaussian standard deviation per axis for the blur applied
            before subsampling; 0 or None turns blurring off.
        last_dim_chan: whether the last axis holds channels.

    Returns:
        The blurred array sliced as ``[offset::stride]`` along every axis.
    """
    rank = len(X.shape)
    step = input2tuple(strides, rank, 1, last_dim_chan)
    start = input2tuple(offsets, rank, 0, last_dim_chan)
    sigma = input2tuple(sigmas, rank, 0, last_dim_chan)

    # With all sigmas at 0 the filter leaves the data as-is, so blurring is
    # effectively off unless the caller asks for it.
    blurred = ndimage.gaussian_filter(X, sigma)

    grid = tuple(slice(first, None, stride) for first, stride in zip(start, step))
    return blurred[grid]
def get_patch(img_in, img_hr_in, patch_size, scale):
    """Randomly crop a square HR patch and build its LR counterpart by subsampling.

    Args:
        img_in: unused; kept for interface compatibility. The LR patch is
            regenerated from the HR crop rather than cropped from this image.
        img_hr_in: high-resolution image indexed (row, col, channel).
        patch_size: side length of the HR crop.
        scale: stride passed to ``downsample``. With the default
            ``last_dim_chan`` this subsamples rows and columns but not
            channels, and no blur is applied.

    Returns:
        ``(lr_patch, hr_patch)``.

    Raises:
        ValueError: from ``random.randrange`` when ``patch_size`` exceeds the
            image height or width.
    """
    height, width = img_hr_in.shape[:2]
    # The column offset is drawn before the row offset.
    left = random.randrange(0, width - patch_size + 1)
    top = random.randrange(0, height - patch_size + 1)

    hr_patch = img_hr_in[top:top + patch_size, left:left + patch_size, :]
    lr_patch = downsample(hr_patch, scale)
    return lr_patch, hr_patch
def get_patch_1D(img_in, img_hr_in, patch_size, scale):
    """Randomly crop a square HR patch and subsample it along the first axis only.

    Args:
        img_in: unused; kept for interface compatibility.
        img_hr_in: high-resolution image indexed (row, col, channel).
        patch_size: side length of the HR crop.
        scale: step used to keep every ``scale``-th row of the crop; columns
            and channels are left untouched.

    Returns:
        ``(lr_patch, hr_patch)``.

    Raises:
        ValueError: from ``random.randrange`` when ``patch_size`` exceeds the
            image height or width.
    """
    height, width = img_hr_in.shape[:2]
    # The column offset is drawn before the row offset.
    left = random.randrange(0, width - patch_size + 1)
    top = random.randrange(0, height - patch_size + 1)

    rows = slice(top, top + patch_size)
    cols = slice(left, left + patch_size)
    hr_patch = img_hr_in[rows, cols, :]
    return hr_patch[::scale, :, :], hr_patch
def set_channel(l, n_channel):
    """Coerce every image in ``l`` to ``n_channel`` channels on the last axis.

    2-D inputs first gain a trailing channel axis. Then:
      * 3 channels requested as 1 -> the first plane of skimage's
        ``rgb2ycbcr`` (the Y / luma component);
      * 1 channel requested as 3 -> the plane is repeated three times;
      * anything else is returned unchanged.

    Args:
        l: iterable of numpy images.
        n_channel: desired channel count.

    Returns:
        list of converted images, in input order.
    """
    def _set_channel(img):
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        channels = img.shape[2]

        if (n_channel, channels) == (1, 3):
            luma = sc.rgb2ycbcr(img)[:, :, 0]
            return luma[:, :, np.newaxis]
        if (n_channel, channels) == (3, 1):
            return np.concatenate([img, img, img], axis=2)
        return img

    return list(map(_set_channel, l))
def np2Tensor(l, rgb_range):
    """Convert (row, col, channel) numpy images into float32 (C, H, W) tensors.

    Args:
        l: iterable of 3-D numpy arrays laid out (row, col, channel).
        rgb_range: accepted for interface compatibility but currently unused;
            the ``rgb_range / 255`` rescale is disabled, so values keep their
            original scale.

    Returns:
        list of ``torch.float32`` tensors, one per input, in input order.
    """
    tensors = []
    for arr in l:
        # Move channels first. The contiguous copy matters because
        # torch.from_numpy rejects negative strides, such as the flipped
        # views produced by augment().
        chw = np.ascontiguousarray(arr.transpose((2, 0, 1)))
        tensors.append(torch.from_numpy(chw).float())
    return tensors
def add_noise(x, noise='.'):
    """Optionally add synthetic noise to an image array.

    Args:
        x: image array. The noisy result is clipped to [0, 255] and cast to
            uint8, so 8-bit intensities are presumably expected (TODO confirm
            for non-8-bit data; values above int16 range would overflow).
        noise: noise spec string. ``'.'`` (the default), ``''`` or ``None``
            mean "no noise". Otherwise the first character selects the type
            and the remainder is an integer parameter:
              * ``'G<sigma>'`` adds rounded Gaussian noise with std ``sigma``;
              * ``'S<k>'`` draws ``Poisson(x * k) / k`` and subtracts the
                per-channel mean of that draw.

    Returns:
        ``x`` itself when no noise is requested, otherwise a new uint8 array.

    Raises:
        ValueError: if the type character is not ``'G'`` or ``'S'``, or the
            parameter is not an integer.
    """
    # Compare by value: the previous `noise is not '.'` tested object
    # identity, which is not guaranteed for equal strings.
    if not noise or noise == '.':
        return x

    noise_type = noise[0]
    noise_value = int(noise[1:])
    if noise_type == 'G':
        noises = np.random.normal(scale=noise_value, size=x.shape).round()
    elif noise_type == 'S':
        # Promote before scaling: for a uint8 `x`, `x * noise_value` stays
        # uint8 and silently wraps around.
        lam = np.asarray(x, dtype=np.float64) * noise_value
        samples = np.random.poisson(lam) / noise_value
        # NOTE(review): the per-channel mean of the draw is subtracted, not
        # `x` itself, so the added term follows image content as well as shot
        # noise. Preserved from the original; confirm it is intended.
        noises = samples - samples.mean(axis=0).mean(axis=0)
    else:
        raise ValueError(
            f"unknown noise type {noise_type!r} in {noise!r}; expected 'G' or 'S'"
        )

    x_noise = x.astype(np.int16) + noises.astype(np.int16)
    return x_noise.clip(0, 255).astype(np.uint8)
def augment(l, hflip=True, rot=True):
    """Apply one random flip/transpose combination to every image in ``l``.

    The random decisions are drawn once per call, so every image in ``l``
    receives the same transformation. Each enabled operation fires with
    probability 0.5, and ``random.random()`` is consulted only for enabled
    operations, in this order:
      1. column reversal (axis 1), enabled by ``hflip``;
      2. row reversal (axis 0), enabled by ``rot``;
      3. swap of rows and columns, enabled by ``rot``.
    Flips combined with the swap cover all 90-degree rotations.

    Args:
        l: iterable of arrays indexed (row, col, channel).
        hflip: allow the column reversal.
        rot: allow the row reversal and the row/column swap.

    Returns:
        list of transformed arrays (views, not copies), in input order.
    """
    ops = []
    if hflip and random.random() < 0.5:
        ops.append(lambda a: a[:, ::-1, :])
    if rot and random.random() < 0.5:
        ops.append(lambda a: a[::-1, :, :])
    if rot and random.random() < 0.5:
        ops.append(lambda a: a.transpose(1, 0, 2))

    def _apply(img):
        for op in ops:
            img = op(img)
        return img

    return [_apply(img) for img in l]