#!/usr/bin/env python
# This script takes a directory of *.png input frame images, upscales them
# using the FBPNSR model pre-trained weights, and dumps the resulting frame
# images in the given output directory.
#
# >upscale --help
#??? Support multiple scenes? Support interpolation?
#??? Once ffmpeg is upgraded: Test -frame_pts true (is it pts or pts_time?)
#??? -vsync is deprecated. Use -fps_mode
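#
# Illustrative invocations (hypothetical file names; flags are the ones
# defined below):
#   upscale video.mp4 upscaled.mp4 --scene 01:30+120
#   upscale frames_dir/ upscaled_frames/ --interpolate --no-cuda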
import argparse
import sys, os, time
import re, glob
import torch
import torch.utils.data as data
import PIL.Image, cv2 # PIL.Image must be imported explicitly to open frames
from fbpn_sr_rbpn_v4_ref import Net as FBPNSR_RBPN_V4_REF
from data import transform
from dataset import get_flow
from torch.autograd import Variable
import utils
from math import ceil
import functools
from ffmpeg import ffmpeg_extract_frames, ffmpeg_replace_frames
print = functools.partial(print, flush=True) # Auto-flush output
parser = argparse.ArgumentParser ( description = 'Super-resolution: Upscale and interpolate video using STARnet' )
parser.add_argument ( 'input', help='input video file or frames directory' )
parser.add_argument ( 'output', nargs='?', help='output frames directory' )
parser.add_argument ( '--scene',
                      metavar='[[HH:]MM:]SS[.ms]|# -/+ [[HH:]MM:]SS[.ms]|#',
                      help='video file time or frame range' )
parser.add_argument ( '--interpolate', dest='interpolate', action='store_true',
                      help='save additional interpolated frames' )
parser.add_argument ( '--no-resume', dest='resume', action='store_false',
                      help='overwrite existing output files' )
parser.add_argument ( '--no-cuda', dest='cuda', action='store_false',
                      help='use CPU even when GPU is available' )
parser.add_argument ( '--gpus', type=int, default=1, metavar='#',
                      help='number of GPUs to use' )
parser.add_argument ( '--threads', type=int, default=1, metavar='#',
                      help='number of threads to use' )
parser.add_argument ( '--testBatchSize', type=int, default=1, metavar='#',
                      help='testing batch size' )
parser.add_argument ( '--chop_forward', action='store_true',
                      help='upscale in overlapping patches to reduce memory use' )
parser.add_argument ( '--debug', type=int, choices=range(3), default=0,
                      help='debug verbosity level (0-2)' )
opt = parser.parse_args()
opt.zoom = 4 # Hard-coded for STARnet
if opt.cuda: opt.cuda = torch.cuda.is_available()
if opt.scene:
    # Time spec: [[HH:]MM:]SS[.ms]; frame spec: a bare integer.
    # Raw strings keep the regex backslashes out of Python's escape handling.
    ts = r'(?:(?:(?:\d+:)?[0-5]?\d):[0-5]?|0\d*)\d(?:\.\d*)?|(?:(?:(?:\d+:)?[0-5]?\d)?:[0-5]?|\d*)\d\.\d*'
    opt.scene = re.search(
        rf"^(?:(?P<st>{ts})(?:-(?P<et1>{ts})|-(?P<ef1>\d+)|\+(?P<tf1>\d+)" \
        +rf"|\+(?P<tt>{ts}))?" \
        +rf"|(?P<sf>\d+)(?:-(?P<et2>{ts})|-(?P<ef2>\d+)|\+(?P<tf2>\d+))?)$",
        opt.scene)
    if not opt.scene:
        raise ValueError("--scene must be a valid time or frame range")
    opt.scene = opt.scene.groupdict()
    # Merge the duplicated end-time/end-frame/total-frame groups:
    for v in ['et','ef','tf']:
        opt.scene[v] = opt.scene[v+'1'] if opt.scene[v+'2'] is None \
                       else opt.scene[v+'2']
        opt.scene.pop(v+'1',None)
        opt.scene.pop(v+'2',None)
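# Illustrative examples of the accepted --scene grammar (values hypothetical);
# the group names are the ones defined in the pattern above:
#   --scene 90           -> sf=90                 start at frame 90
#   --scene 90-240       -> sf=90, ef=240         frame range
#   --scene 01:30-02:00  -> st=01:30, et=02:00    time range
#   --scene 01:30+120    -> st=01:30, tf=120      start time + frame count
#   --scene 5.5+2.0      -> st=5.5, tt=2.0        start time + duration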
input_dir = opt.input
if not os.path.exists ( input_dir ) or \
   ( opt.scene and not os.path.isfile ( input_dir ) ):
    raise FileNotFoundError(f"Input '{input_dir}' not found.")
elif os.path.isfile ( input_dir ):
    # A video file: frames get extracted into a hidden '.<name>.inp' directory.
    input_dir = os.path.join ( os.path.dirname ( input_dir ),
                               '.'+os.path.basename ( input_dir )+'.inp' )
else:
    opt.input = re.sub (r'/$', '', opt.input )
output_dir = opt.output
if opt.output is not None and os.path.isfile ( opt.input ) and \
   ( os.path.isfile ( opt.output ) or \
     re.search(r'(^|/)[^.].*\.[^/]+$',opt.output) ):
    # Output names a video file: upscale into a hidden '.<name>.out' directory,
    # then mux the frames back into opt.output at the end.
    output_dir = None
else:
    opt.output = '' # Frames-only output; skip the final FFmpeg remux
if output_dir is None:
    output_dir = os.path.join ( os.path.dirname ( opt.input ),
                                '.'+os.path.basename ( opt.input )+'.out' )
if input_dir == output_dir: opt.resume = False
if opt.debug > 1: print('opt =', opt)
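# Illustrative path mapping (hypothetical file names):
#   upscale clip.mp4          -> frames in '.clip.mp4.inp/', results in
#                                '.clip.mp4.out/' (kept as PNGs; no remux)
#   upscale clip.mp4 big.mp4  -> as above, then frames are muxed into 'big.mp4'
#   upscale frames/ results/  -> PNGs read from 'frames/', written to 'results/'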
if os.path.isfile ( opt.input ): # Extract frames using FFmpeg
    ffmpeg_extract_frames(opt, input_dir)
if not os.path.exists(output_dir): os.makedirs(output_dir)
if opt.debug:
    print('Setting up model: FBPNSR')
else:
    print('Loading model: FBPNSR ...', end='')
pretrained_weights = os.path.join ( os.path.dirname ( sys.argv[0] ),
                                    "weights/pretrained/rbpn_pretrained_F2_4x.pth" )
if opt.debug:
    print(f" Loading pre-trained weights: file = '{pretrained_weights}'")
model = FBPNSR_RBPN_V4_REF ( base_filter=256, feat=64, num_stages=3,
                             n_resblock=5,
                             weights=pretrained_weights, scale_factor=4 )
if opt.cuda:
    model = torch.nn.DataParallel(model, device_ids=list(range(opt.gpus)))
else:
    # Wrap on the CPU too, so the checkpoint's 'module.'-prefixed keys match:
    model = torch.nn.DataParallel(model)
pretrained_weights = os.path.join ( os.path.dirname ( sys.argv[0] ),
                                    "weights/FBPNSR_RBPN_V4_REF_Lr_STAR_ST.pth" )
if opt.debug:
    print(f" Loading pre-trained weights: file = '{pretrained_weights}'")
model.load_state_dict(torch.load ( pretrained_weights,
                                   map_location=lambda storage, loc: storage ))
if opt.cuda: model = model.cuda(0)
model.eval()
if not opt.debug: print(' done.')
# Memory-saving alternative to a whole-frame forward pass (--chop_forward):
# split each input into four overlapping quadrants, upscale them (recursively,
# if still too large), and stitch the results back together. Defined here,
# before the processing loop below that calls it.
def chop_forward(t_im1, t_im2, t_flow_f, t_flow_b, model, iter=None,
                 shave=8, min_size=160000, nGPUs=opt.gpus):
    b, c, h, w = t_im1.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    mod_size = 4 # Patch sides must be multiples of 4 for the network
    if h_size%mod_size:
        h_size = ceil(h_size/mod_size)*mod_size
    if w_size%mod_size:
        w_size = ceil(w_size/mod_size)*mod_size
    # Four overlapping quadrants (top-left, top-right, bottom-left,
    # bottom-right), each [image, next image, forward flow, backward flow, iter]:
    inputlist = [
        [t_im1[:, :, 0:h_size, 0:w_size], t_im2[:, :, 0:h_size, 0:w_size],
         t_flow_f[:, :, 0:h_size, 0:w_size], t_flow_b[:, :, 0:h_size, 0:w_size], iter],
        [t_im1[:, :, 0:h_size, (w - w_size):w], t_im2[:, :, 0:h_size, (w - w_size):w],
         t_flow_f[:, :, 0:h_size, (w - w_size):w], t_flow_b[:, :, 0:h_size, (w - w_size):w], iter],
        [t_im1[:, :, (h - h_size):h, 0:w_size], t_im2[:, :, (h - h_size):h, 0:w_size],
         t_flow_f[:, :, (h - h_size):h, 0:w_size], t_flow_b[:, :, (h - h_size):h, 0:w_size], iter],
        [t_im1[:, :, (h - h_size):h, (w - w_size):w], t_im2[:, :, (h - h_size):h, (w - w_size):w],
         t_flow_f[:, :, (h - h_size):h, (w - w_size):w], t_flow_b[:, :, (h - h_size):h, (w - w_size):w], iter]]
    if w_size * h_size < min_size: # Small enough: run the model directly
        outputlist = []
        for i in range(0, 4, nGPUs):
            with torch.no_grad():
                input_batch = inputlist[i] #torch.cat(inputlist[i:(i + nGPUs)], dim=0)
                output_batch = model(input_batch[0], input_batch[1],
                                     input_batch[2], input_batch[3], train=False)
                outputlist.extend(output_batch.chunk(nGPUs, dim=0))
    else: # Still too large: chop each quadrant recursively
        outputlist = [
            chop_forward(patch[0], patch[1], patch[2], patch[3], model,
                         patch[4], shave, min_size, nGPUs) \
            for patch in inputlist]
    # Stitch the quadrants back together, discarding the 'shave' overlap.
    # NOTE: this stitching assumes each patch result is a single tensor at the
    # input resolution (scale=1); the model's multi-output interface (see the
    # comment in the processing loop) may need adapting here.
    scale=1
    h, w = scale * h, scale * w
    h_half, w_half = scale * h_half, scale * w_half
    h_size, w_size = scale * h_size, scale * w_size
    shave *= scale
    with torch.no_grad():
        output = Variable(t_im1.data.new(b, c, h, w))
        output[:, :, 0:h_half, 0:w_half] \
            = outputlist[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
    return output
# Define iterator object (ie, with __len__() and __getitem__()) for frames:
class frames_from_files(data.Dataset):
    def __init__(self, input, output):
        super(frames_from_files, self).__init__()
        self.files = glob.glob(os.path.join(input,'*.png'))
        self.files.sort()
        # Preset the destination file names:
        self.dests = { i:os.path.join(output,os.path.basename(self.files[i]))
                       for i in range(len(self.files)) }
        if opt.interpolate:
            # Interpolated frames go between frames i and i+1, with 'z1'
            # appended to the base name so a lexical sort keeps them in order:
            for i in range(len(self.files)-1):
                self.dests[i+0.5] = re.sub (r'(\.[^.]+)$', r'z1\g<1>', self.dests[i] )
        self.images = {}
    def var(self, image):
        with torch.no_grad():
            if opt.cuda:
                return Variable(image).cuda(0)
            else:
                return Variable(image)
    # Cache images:
    def load_img(self, n):
        if n >= len(self.files):
            return False
        elif n not in self.images.keys():
            if opt.debug: print(f" Loading: file = '{self.files[n]}'; frame #={n}")
            self.images[n] = { 'raw': PIL.Image.open(self.files[n]).convert('RGB') }
            self.images[n]['var'] = self.var(transform()( self.images[n]['raw'] ))
        return self.images[n]
    def save_img(self, img, i):
        # Tensor (C,H,W) in [0,1] -> uint8 (H,W,C); cv2 wants BGR channel order:
        save_img = img.squeeze().clamp(0, 1).numpy().transpose(1,2,0)
        save_img = (save_img*255).round().astype('uint8')
        cv2.imwrite(self.dests[i], cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR),
                    [cv2.IMWRITE_PNG_COMPRESSION, 0])
    def __len__(self):
        return len(self.files)-1
    # A frame consists of an input image to upscale, the next image to use
    # for context, and their corresponding flows:
    def __getitem__(self, i):
        if opt.resume:
            if os.path.exists ( self.dests[i] ) and \
               ( not opt.interpolate or os.path.exists ( self.dests[i+0.5] ) ) and \
               ( i != self.__len__()-1 or os.path.exists ( self.dests[i+1] ) ):
                if opt.debug:
                    print(f" Skipping: file = '{self.files[i]}'; frame #={i};"
                          +f" all destinations for '{self.dests[i]}' exist")
                return [ 0, 0, 0, 0, -1 ] # Sentinel index -1: nothing to do
        input = self.load_img(i)['var']   # The image to upscale
        next  = self.load_img(i+1)['var'] # The next frame
        # Dense optical flow between the two frames, in both directions:
        flowi2n = get_flow(self.load_img(i)['raw'],self.load_img(i+1)['raw'])
        flowi2n = self.var(torch.from_numpy(flowi2n.transpose(2,0,1))).float()
        flown2i = get_flow(self.load_img(i+1)['raw'],self.load_img(i)['raw'])
        flown2i = self.var(torch.from_numpy(flown2i.transpose(2,0,1))).float()
        return [ input, next, flowi2n, flown2i, i ]
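# Example of the mapping built above (hypothetical frame names), with
# --interpolate and three input frames:
#   files = ['f1.png', 'f2.png', 'f3.png']
#   dests = { 0:'out/f1.png', 0.5:'out/f1z1.png', 1:'out/f2.png',
#             1.5:'out/f2z1.png', 2:'out/f3.png' }
#   __getitem__(0) -> [frame 1, frame 2, flow 1->2, flow 2->1, 0]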
if opt.debug:
    print('Processing frames...')
else:
    print(f"Processing frames in '{input_dir}'", end='')
frames = frames_from_files ( input_dir, output_dir )
for frame in data.DataLoader ( dataset=frames, num_workers=opt.threads,
                               batch_size=opt.testBatchSize, shuffle=False ):
    i = frame.pop().numpy()[0] # Hacky way to get index...
    if i < 0: # Sentinel: all destinations exist, nothing to generate
        if not opt.debug: print('.', end='')
        continue
    if opt.debug: print(f" Generating: file = '{frames.files[i]}'"
                        +f" -> '{frames.dests[i]}'")
    t0 = time.time()
    with torch.no_grad():
        if opt.chop_forward:
            predictions = chop_forward(*frame, model)
        else:
            predictions = model(*frame, train=False)
        # predictions = high-resolution interpolated (i+0.5) image,
        #               high-resolution input (i) image,
        #               high-resolution next (i+1) image,
        #               low-resolution interpolated (i+0.5) image
    if opt.debug: print(f" Saving: file = '{frames.dests[i]}'"
                        +f" took {time.time() - t0 :0.1f}s to generate")
    predictions[1][0] = utils.denorm(predictions[1][0].cpu().data,vgg=True)
    frames.save_img ( predictions[1][0], i )
    if opt.interpolate:
        if opt.resume and os.path.exists ( frames.dests[i+0.5] ):
            if opt.debug: print(f" Skipping: file = '{frames.dests[i+0.5]}'")
        else:
            if opt.debug: print(f" Saving: file = '{frames.dests[i+0.5]}'")
            predictions[0][0] = utils.denorm(predictions[0][0].cpu().data,vgg=True)
            frames.save_img ( predictions[0][0], i+0.5 )
    if i == frames.__len__()-1: # Last pair: also save the upscaled final frame
        if opt.resume and os.path.exists ( frames.dests[i+1] ):
            if opt.debug: print(f" Skipping: file = '{frames.dests[i+1]}'")
        else:
            if opt.debug: print(f" Saving: file = '{frames.dests[i+1]}'")
            predictions[2][0] = utils.denorm(predictions[2][0].cpu().data,vgg=True)
            frames.save_img ( predictions[2][0], i+1 )
    if not opt.debug: print('.', end='')
if not opt.debug: print('')
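# At this point output_dir holds the upscaled sequence in lexical order, e.g.
# (hypothetical names, with --interpolate): f1.png, f1z1.png, f2.png, f2z1.png,
# f3.png; the extra 'z1' frames can double the frame rate if remuxed accordingly.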
if opt.output: # Replace frames using FFmpeg (experimental)
    ffmpeg_replace_frames(opt, input_dir, output_dir)