diff --git a/main_test_swinir.py b/main_test_swinir.py index f06b5f39d..cebcd8805 100644 --- a/main_test_swinir.py +++ b/main_test_swinir.py @@ -10,6 +10,9 @@ from models.network_swinir import SwinIR as net from utils import util_calculate_psnr_ssim as util +def dbg(tag, arr): + print(f'{tag}: shape={arr.shape}, dtype={arr.dtype}, ' + f'range=({arr.min():.3f}, {arr.max():.3f})') def main(): parser = argparse.ArgumentParser() @@ -24,10 +27,13 @@ def main(): parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr') parser.add_argument('--model_path', type=str, default='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth') + + parser.add_argument('--endoscope_data', type=bool, default=True, help='use endoscope data, which you need to pass in LR and HR folder where image with same name will be paired') parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder') parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder') parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)') parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles') + parser.add_argument('--save_dir', type=str, default=None, help='output directory') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -47,6 +53,9 @@ def main(): # setup folder and path folder, save_dir, border, window_size = setup(args) + if args.endoscope_data: + hq_folder = folder[1] + folder = folder[0] os.makedirs(save_dir, exist_ok=True) test_results = OrderedDict() test_results['psnr'] = [] @@ -59,14 +68,25 @@ def main(): for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))): # read image - imgname, img_lq, img_gt = get_image_pair(args, path) # image to HWC-BGR, float32 + if not args.endoscope_data: + imgname, img_lq, img_gt = get_image_pair(args, path) # image to HWC-BGR, float32 + else: + imgname = path.split('/')[-1] + img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255. + if args.folder_gt: + img_gt_path = os.path.join(hq_folder, imgname) + if os.path.exists(img_gt_path): + img_gt = cv2.imread(img_gt_path, cv2.IMREAD_COLOR).astype(np.float32) / 255. + else: + img_gt = None + img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)) # HCW-BGR to CHW-RGB img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device) # CHW-RGB to NCHW-RGB # inference with torch.no_grad(): # pad input image to be a multiple of window_size - _, _, h_old, w_old = img_lq.size() + _, _, h_old, w_old = img_lq.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :] @@ -194,31 +214,45 @@ def define_model(args): def setup(args): # 001 classical image sr/ 002 lightweight image sr + save_dir = args.save_dir if args.task in ['classical_sr', 'lightweight_sr']: - save_dir = f'results/swinir_{args.task}_x{args.scale}' - folder = args.folder_gt + if not args.save_dir: + save_dir = f'results/swinir_{args.task}_x{args.scale}' + if args.endoscope_data: + folder = [args.folder_lq, args.folder_gt] + else: + folder = args.folder_gt border = args.scale window_size = 8 # 003 real-world image sr elif args.task in ['real_sr']: - save_dir = f'results/swinir_{args.task}_x{args.scale}' + if not args.save_dir: + save_dir = f'results/swinir_{args.task}_x{args.scale}' if args.large_model: save_dir += '_large' - folder = args.folder_lq + if args.endoscope_data: + folder = [args.folder_lq, args.folder_gt] + else: + folder = args.folder_lq border = 0 window_size = 8 # 004 grayscale image denoising/ 005 color image denoising elif args.task in ['gray_dn', 'color_dn']: - save_dir = f'results/swinir_{args.task}_noise{args.noise}' - folder = args.folder_gt + if not args.save_dir: + save_dir = f'results/swinir_{args.task}_noise{args.noise}' + if args.endoscope_data: + folder = [args.folder_lq, args.folder_gt] + else: + folder = args.folder_gt border = 0 window_size = 8 # 006 JPEG compression artifact reduction elif args.task in ['jpeg_car', 'color_jpeg_car']: - save_dir = f'results/swinir_{args.task}_jpeg{args.jpeg}' + if not args.save_dir: + save_dir = f'results/swinir_{args.task}_jpeg{args.jpeg}' folder = args.folder_gt border = 0 window_size = 7