-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
365 lines (330 loc) · 15 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import time
# Taken before the remaining imports, so the total printed at the end of the script includes import time.
everything_start_time = time.time()
import argparse
import json
import os
import shutil
import subprocess

import cv2
import numpy

import SRers
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy. This converter accepts the usual
    spellings and rejects anything else with a normal argparse error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
# Input/output file
parser.add_argument('-i', '--input',  # Input file
                    type=str, required=True,
                    help='Path of the input: a video file, a folder of frames (images/npy/npz), '
                         'a folder containing several inputs, or a process_info.json to continue from')
parser.add_argument('-o', '--output',  # Output file
                    type=str, default='default',
                    help='Specify output file name. '
                         'Default: <input directory>/<input name>_<algorithm>')
parser.add_argument('-ot', '--output_type',  # Output file type
                    type=str, choices=['video', 'npz', 'npy', 'tiff', 'png'], default='npy',
                    help='Output file type. For video, -o is a file; '
                         'for image sequences, npy or npz, -o is a folder')
# Process type
parser.add_argument('-a', '--algorithm', type=str, default='EDVR',  # Algorithm
                    choices=['EDVR', 'ESRGAN'], help='EDVR or ESRGAN')
# No '--' option string is given here, so argparse stores the value under the key 'mn'.
parser.add_argument('-mn', '-model_name', type=str, default='mt4r',
                    choices=['ld', 'ldc', 'l4r', 'l4v', 'l4br', 'm4r', 'mt4r'],
                    help='ld: L Deblur, ldc: L Deblur Comp, l4r: L SR REDS x4, l4v: L SR vimeo90K 4x, '
                         'l4br: L SRblur REDS 4x, m4r: M woTSA SR REDS 4x, mt4r: M SR REDS 4x')
# Model directory
parser.add_argument('-md', '--model_path',  # Model path
                    type=str, default='default',
                    help='path of checkpoint for pretrained model')
# Start/End frame
parser.add_argument('-st', '--start_frame',  # Start frame
                    type=int, default=1,
                    help='specify start frame (Start from 1)')
parser.add_argument('-ed', '--end_frame',  # End frame
                    type=int, default=0,
                    help='specify end frame. Default: Final frame')
# FFmpeg
parser.add_argument('-fd', '--ffmpeg_dir',  # FFmpeg directory
                    type=str, default='',
                    help='path to ffmpeg(.exe)')
parser.add_argument('-vc', '--vcodec',  # Video codec
                    type=str, default='h264',
                    help='Video codec')
parser.add_argument('-ac', '--acodec',  # Audio codec
                    type=str, default='copy',
                    help='Audio codec')
# NOTE(review): bit_rate is parsed but not read anywhere in the visible part of this script.
parser.add_argument('-br', '--bit_rate',  # Video bit rate
                    type=str, default='',
                    help='Bit rate for output video')
parser.add_argument('-fps',  # Target frame rate
                    type=float,
                    help='specify fps of output video. '
                         'Default: fps of the input video, or 30 for frame-sequence input.')
parser.add_argument('-mc', '--mac_compatibility',  # Make the result directly playable on Apple devices
                    type=_str2bool, default=True,
                    help='If you want to play it on a mac with QuickTime or iOS, set this to True and the pixel '
                         'format will be yuv420p. ')
# Other
parser.add_argument('-bs', '--batch_size',  # Batch Size
                    type=int, default=1,
                    help='Specify batch size for faster conversion. This will depend on your cpu/gpu memory. Default: 1')
parser.add_argument('-ec', '--empty_cache',  # Empty cache
                    type=int, default=0,
                    help='Empty cache while processing, set to 1 if you get CUDA out of memory errors; '
                         'if the process runs fine, setting this to 1 will only slow it down. ')
# Temporary files
parser.add_argument('-tmp', '--temp_file_path',  # Temporary file path
                    type=str, default='tmp',
                    help='Specify temporary file path')
parser.add_argument('-rm', '--remove_temp_file',  # Whether to remove temporary files
                    type=_str2bool, default=False,
                    help='Set to True to delete the temporary files after processing; '
                         'the default (False) keeps them')
# Plain dict view of the parsed namespace; the rest of the script indexes it by key.
args = vars(parser.parse_args())
# Default checkpoint location for every algorithm / model-name pair.
# Outer key: algorithm name; inner key: short model name.
model_paths = dict(
    EDVR=dict(
        ld='BasicSR/experiments/pretrained_models/EDVR/EDVR_L_deblur_REDS_official-ca46bd8c.pth',
        ldc='BasicSR/experiments/pretrained_models/EDVR/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth',
        l4v='BasicSR/experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth',
        l4r='BasicSR/experiments/pretrained_models/EDVR/EDVR_L_x4_SR_REDS_official-9f5f5039.pth',
        l4br='BasicSR/experiments/pretrained_models/EDVR/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth',
        m4r='BasicSR/experiments/pretrained_models/EDVR/EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth',
        mt4r='BasicSR/experiments/pretrained_models/EDVR/EDVR_M_x4_SR_REDS_official-32075921.pth',
    ),
    ESRGAN=dict(
        test='ESRGAN.pth',
    ),
)
def listdir(folder):
    """Return the sorted entries of *folder*, leaving out OS/editor clutter.

    Skipped entries are the fixed names listed below plus anything whose name
    starts with ``._``.
    """
    clutter = ['.DS_Store', '.ipynb_checkpoints', '$RECYCLE.BIN', 'Thumbs.db', 'desktop.ini']
    return sorted(
        entry for entry in os.listdir(folder)
        if entry not in clutter and entry[:2] != '._'
    )
class data_loader:
    """Frame source with one ``read()`` interface for videos and frame folders.

    After construction the instance exposes ``frame_count``, ``height`` and
    ``width`` (and ``fps`` for video input), and ``read()`` returns the next
    frame.

    Args:
        input_dir: Path of the video file, or of the folder holding the frames.
        input_type: ``'video'``, or one of the sequence types ``'is'`` (read
            with ``cv2.imread``), ``'npz'`` (array stored under ``'arr_0'``)
            or ``'npy'``.
        start_frame: Zero-based index of the first frame to deliver.

    Raises:
        ValueError: For sequence input, when no file is left at or after
            ``start_frame``.
    """

    def __init__(self, input_dir, input_type, start_frame):
        self.input_type = input_type
        self.input_dir = input_dir
        self.start_frame = start_frame
        # One reader per sequence type; each takes a file path and returns an array.
        self.sequence_read_funcs = {'is': cv2.imread,
                                    'npz': lambda path: numpy.load(path)['arr_0'],
                                    'npy': numpy.load
                                    }
        if input_type == 'video':
            self.cap = cv2.VideoCapture(input_dir)
            # Seek so the first read() returns frame number start_frame.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.start_frame)
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)
            # Frame count as reported by the capture; it is not reduced by start_frame.
            self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
            self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.read = self.video_func
        else:
            self.count = -1  # Index of the last frame handed out; -1 means none yet.
            self.files = [f'{input_dir}/{f}' for f in listdir(input_dir)[self.start_frame:]]
            self.frame_count = len(self.files)
            if not self.files:
                raise ValueError(f'No frames found in {input_dir} from frame index {start_frame}')
            # Load the first frame once, only to learn the frame size.
            first_shape = self.sequence_read_funcs[input_type](self.files[0]).shape
            self.height = first_shape[0]
            self.width = first_shape[1]
            self.read = self.sequence_func

    def video_func(self):
        """Return the result of ``VideoCapture.read()`` for the next video frame."""
        return self.cap.read()

    def sequence_func(self):
        """Return ``(True, frame)`` for the next file, or ``(False, None)`` when exhausted or unreadable."""
        self.count += 1
        if self.count < self.frame_count:
            img = self.sequence_read_funcs[self.input_type](self.files[self.count])
            if img is not None:
                return True, img
        return False, None

    def close(self):
        """Release the underlying video capture; a no-op for sequence input."""
        if self.input_type == 'video':
            # cv2.VideoCapture has no close(); release() is the call that frees the device/file.
            self.cap.release()
def data_writer(output_type):
    """Return the ``save(path, data)`` callable for *output_type*.

    ``'tiff'`` and ``'png'`` writers append the extension to *path* and write
    through ``cv2.imwrite``; ``'npz'`` and ``'npy'`` map straight to the numpy
    savers. Any other value raises ``KeyError``.
    """
    def image_writer(extension):
        def write(path, img):
            return cv2.imwrite(path + extension, img)
        return write

    writers = {
        'tiff': image_writer('.tiff'),
        'png': image_writer('.png'),
        'npz': numpy.savez_compressed,
        'npy': numpy.save,
    }
    return writers[output_type]
def detect_input_type(input_dir):  # Detect the input type
    """Classify *input_dir* by what it points at.

    Returns:
        ``'continue'`` for a ``.json`` file, ``'video'`` for any other file.
        For a folder the first listed entry decides: ``'npz'``, ``'npy'``,
        ``'is'`` for a known image extension, otherwise ``'mix'``.
    """
    if os.path.isfile(input_dir):
        is_json = os.path.splitext(input_dir)[1].lower() == '.json'
        return 'continue' if is_json else 'video'

    # Folder input: only the extension of the first entry is inspected.
    first_suffix = os.path.splitext(listdir(input_dir)[0])[1].lower()
    if first_suffix == '.npz':
        return 'npz'
    if first_suffix == '.npy':
        return 'npy'
    image_extensions = ['dpx', 'jpg', 'jpeg', 'exr', 'psd', 'png', 'tif', 'tiff']
    if first_suffix.replace('.', '') in image_extensions:
        return 'is'
    return 'mix'
def check_output_dir(dire, ext=''):
    """Pick a free output path based on *dire* and prepare the filesystem for it.

    The parent directory is created when missing. If ``dire + ext`` already
    exists, ``_2``, ``_3``, ... is appended to *dire* until an unused name is
    found. With an empty *ext* the result is treated as a folder and created;
    otherwise only the path is returned and the file itself is not touched.

    Args:
        dire: Output path without extension.
        ext: File extension including the dot, or ``''`` for a folder.

    Returns:
        The chosen path (with *ext* appended).
    """
    parent = os.path.split(dire)[0]
    # A bare name such as 'out' has an empty parent; os.makedirs('') would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    target = f'{dire}{ext}'
    if os.path.exists(target):  # If target file/folder exists
        count = 2
        while os.path.exists(f'{dire}_{count}{ext}'):
            count += 1
        target = f'{dire}_{count}{ext}'
    if not ext:  # Output as folder
        os.mkdir(target)
    return target
def second2time(second: float) -> str:
    """Format a duration given in seconds as ``H:MM:SS.ss``."""
    minutes, seconds = divmod(second, 60)
    hours, minutes = divmod(minutes, 60)
    return '%d:%02d:%05.2f' % (hours, minutes, seconds)
def _has_audio_stream(media_path):
    """Return True when ffprobe lists at least one audio stream for *media_path*.

    The ffprobe JSON is parsed with ``json.loads`` instead of ``eval`` so tool
    output is never executed as Python. A missing ffprobe binary or unparsable
    output counts as "no audio" rather than aborting after all frames have
    been processed.
    """
    probe_cmd = ['ffprobe', '-v', 'quiet', '-show_streams', '-select_streams', 'a',
                 '-print_format', 'json', media_path]
    try:
        probe = subprocess.run(probe_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        info = json.loads(probe.stdout.decode('utf-8', errors='replace') or '{}')
    except (OSError, ValueError):
        return False
    # An empty "streams" list means the file was readable but has no audio stream.
    return isinstance(info, dict) and bool(info.get('streams'))


def _encode_video(cag, with_audio):
    """Assemble the temporary TIFF frames into ``cag['dest_path']`` with ffmpeg.

    Arguments are passed as a list (no shell), so paths containing spaces or
    quotes need no escaping. The argument order is the one the original shell
    command produced.

    Returns:
        True when ffmpeg ran and exited with status 0, False otherwise.
    """
    # Pixel-format / tag options. Must be a list even when mac_compatibility is off,
    # because the hevc tag is appended to it.
    pixel_options = ['-pix_fmt', 'yuv420p'] if cag['mac_compatibility'] else []
    if 'hevc' in cag['vcodec']:
        pixel_options.extend(['-vtag', 'hvc1'])
    audio_head = ['-thread_queue_size', '1048576'] if with_audio else []
    audio_input = ['-vn', '-i', cag['input_file_path']] if with_audio else []
    # NOTE(review): -acodec sits before the second -i exactly as in the original command;
    # confirm that ffmpeg treats it as intended there.
    audio_codec = ['-acodec', cag['acodec']] if with_audio else []
    cmd = [os.path.join(cag['ffmpeg_dir'], 'ffmpeg'),
           *audio_head,
           '-loglevel', 'error',
           *audio_input,
           '-vsync', '0',
           '-r', str(cag['fps']),
           *audio_codec,
           '-pattern_type', 'glob',
           '-i', os.path.join(cag['temp_folder'], 'tiff/*.tiff'),
           '-vcodec', cag['vcodec'], *pixel_options, '-crf', '20',
           cag['dest_path']]
    try:
        return subprocess.run(cmd).returncode == 0
    except OSError as error:
        print(f'Could not run ffmpeg: {error}')
        return False


input_type = detect_input_type(args['input'])
if input_type == 'mix':
    # A folder classified as 'mix' is handled as a batch: every entry is processed as its own input.
    processes = [os.path.join(args['input'], process) for process in listdir(args['input'])]
else:
    processes = [args['input']]

# Extra work
# The command line counts frames from 1; the loader is handed this value minus one.
args['start_frame'] -= 1

for input_file_path in processes:
    input_type = detect_input_type(input_file_path)
    if input_type != 'continue':
        # Becomes [directory, stem, extension] of the input path.
        input_file_name_list = list(os.path.split(input_file_path))
        input_file_name_list.extend(os.path.splitext(input_file_name_list[1]))
        input_file_name_list.pop(1)
        temp_file_path = check_output_dir(os.path.join(args['temp_file_path'], input_file_name_list[1]))
        video = data_loader(input_file_path, input_type, args['start_frame'])
        frame_count = video.frame_count
        frame_count_len = len(str(frame_count))
        if args['fps']:
            fps = args['fps']
        elif input_type == 'video':
            fps = video.fps
        else:
            fps = 30
        # Start/End frame
        if args['end_frame'] == 0 or args['end_frame'] >= frame_count:
            end_frame = frame_count
        else:
            end_frame = args['end_frame'] + 1
        if args['start_frame'] == 0 or args['start_frame'] >= frame_count:
            start_frame = 1
        else:
            start_frame = args['start_frame']
        if args['model_path'] == 'default':  # Model path
            # NOTE(review): model_paths['ESRGAN'] only has the key 'test', which is not among
            # the -mn choices, so this lookup raises KeyError for ESRGAN without -md.
            model_path = model_paths[args['algorithm']][args['mn']]
        else:
            model_path = args['model_path']
        output_type = args['output_type']
        output_dir = args['output']
        if output_dir == 'default':
            # os.path.join keeps the result relative when the input has no directory part
            # (an f-string with '/' would point at the filesystem root).
            output_dir = os.path.join(input_file_name_list[0],
                                      f"{input_file_name_list[1]}_{args['algorithm']}")
            if output_type == 'video':
                ext = input_file_name_list[2] if input_file_name_list[2] else '.mp4'
        else:
            output_dir, ext = os.path.splitext(output_dir)
            if output_type == 'video' and not ext:
                ext = '.mp4'  # Same fallback as the default-name branch.
        # Give bare names an explicit parent so the directory handling below never sees ''.
        if not os.path.dirname(output_dir):
            output_dir = os.path.join(os.curdir, output_dir)
        os.makedirs(os.path.dirname(output_dir), exist_ok=True)
        if output_type == 'video':
            # output_dir carries no extension at this point; splitting it again would
            # cut names that contain a dot.
            dest_path = check_output_dir(output_dir, ext)
            # Frames go to the temp folder as TIFF first and are encoded afterwards.
            output_dir = f'{temp_file_path}/tiff'
            output_type = 'tiff'
        else:
            dest_path = False
        os.makedirs(output_dir, exist_ok=True)
        # Everything needed to run (or later continue) this job; saved next to the temp files.
        cag = {'input_file_path': input_file_path,
               'input_type': input_type,
               'empty_cache': args['empty_cache'],
               'model_path': model_path,
               'temp_folder': temp_file_path,
               'algorithm': args['algorithm'],
               'frame_count': frame_count,
               'frame_count_len': frame_count_len,
               'height': video.height,
               'width': video.width,
               'start_frame': start_frame,
               'end_frame': end_frame,
               'model_name': args['mn'],
               'batch_size': args['batch_size'],
               'output_type': output_type,
               'output_dir': output_dir,
               'dest_path': dest_path,
               'mac_compatibility': args['mac_compatibility'],
               'ffmpeg_dir': args['ffmpeg_dir'],
               'fps': fps,
               'vcodec': args['vcodec'],
               'acodec': args['acodec'],
               'remove_temp_file': args['remove_temp_file']
               }
        with open(f'{temp_file_path}/process_info.json', 'w') as f:
            json.dump(cag, f)
    else:
        # Continue a previous job from its saved process_info.json.
        with open(input_file_path, 'r') as f_:
            cag = json.load(f_)
        # NOTE(review): 'sf' is not among the keys this script writes into process_info.json,
        # so this raises KeyError for files produced here. Frames are also saved from index 0
        # below, so a continued run would overwrite earlier outputs - confirm before relying on it.
        start_frame = len(listdir(cag['output_dir'])) // cag['sf']
        video = data_loader(cag['input_file_path'], cag['input_type'], start_frame - 1)
    if cag['empty_cache']:
        os.environ['CUDA_EMPTY_CACHE'] = str(cag['empty_cache'])
    # Model checking
    if not os.path.exists(cag['model_path']):
        print(f"Model {cag['model_path']} doesn't exist, exiting")
        raise SystemExit(1)
    # Start frame
    # NOTE(review): the division uses (frame_count - start_frame + 1) but the remainder test
    # uses (frame_count - start_frame); kept as in the original.
    batch_count = (cag['frame_count'] - start_frame + 1) // cag['batch_size']
    if (cag['frame_count'] - start_frame) % cag['batch_size']:
        batch_count += 1
    # Super resolution
    SRer = getattr(SRers, cag['algorithm']).SRer(cag['model_name'], cag['model_path'], cag['height'], cag['width'])
    SRer.init_batch(video)
    save = data_writer(cag['output_type'])
    timer = 0
    initialize_time = 0  # Stays 0 when there is nothing to process, so the summary below still works.
    start_time = time.time()
    try:
        for i in range(batch_count):
            out = SRer.sr(video.read())
            save(f"{cag['output_dir']}/{str(i).zfill(cag['frame_count_len'])}", out)
            time_spent = time.time() - start_time
            start_time = time.time()
            if i == 0:
                # The first batch includes model warm-up, so it is kept out of the running average.
                initialize_time = time_spent
                print(f'Initialized and processed frame 1/{batch_count} | '
                      f'{batch_count - i - 1} frames left | '
                      f'Time spent: {round(initialize_time, 2)}s',
                      end='')
            else:
                timer += time_spent
                frames_processes = i + 1
                frames_left = batch_count - frames_processes
                print(f'\rProcessed batch {frames_processes}/{batch_count} | '
                      f"{frames_left} {'batches' if frames_left > 1 else 'batch'} left | "
                      f'Time spent: {round(time_spent, 2)}s | '
                      f'Time left: {second2time(frames_left * timer / i)} | '
                      f'Total time spend: {second2time(timer + initialize_time)}', end='', flush=True)
    except KeyboardInterrupt:
        print('\nCaught Ctrl-C, exiting. ')
        # 130 is the conventional "interrupted" status; the former 256 wraps to 0 (success) on POSIX.
        raise SystemExit(130)
    del video, SRer
    print(f'\r{os.path.split(input_file_path)[1]} done! Total time spend: {second2time(timer + initialize_time)}', flush=True)
    # Video post process
    encode_ok = True
    if cag['dest_path']:
        # NOTE(review): for configs written above, end_frame is never stored as 0 (a 0 from the
        # command line is replaced by frame_count), so this audio branch is only reachable when
        # frame_count itself is 0 - confirm the intended condition.
        with_audio = (cag['start_frame'] == 1 and cag['end_frame'] == 0
                      and _has_audio_stream(cag['input_file_path']))
        encode_ok = _encode_video(cag, with_audio)
        if not encode_ok:
            print(f"ffmpeg failed to write {cag['dest_path']}")
    if cag['remove_temp_file']:
        if encode_ok:
            shutil.rmtree(cag['temp_folder'])
        else:
            # The processed frames are the only result when encoding failed; do not delete them.
            print(f"Keeping temporary files in {cag['temp_folder']}")
print(time.time() - everything_start_time)