-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
73 lines (59 loc) · 2.49 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/python
import os
import argparse as arg
import shutil
import torch
import numpy as np
import cv2
from PIL import Image
from glob import glob
from data import load_testloader
from model import DenseDepth
from utils import colorize, DepthNorm, AverageMeter, load_images
from losses import ssim as ssim_criterion
from losses import depth_loss as gradient_criterion
def _parse_args():
    """Build this script's command-line parser and return the parsed arguments.

    Returns:
        argparse.Namespace with ``checkpoint``, ``device``, ``data``, ``cmap``
        and ``save_dir`` attributes.
    """
    parser = arg.ArgumentParser(description="Test the model that has been trained")
    # required: the weights are loaded unconditionally, and the old optional
    # flag crashed with ``len(None)`` when it was omitted.
    parser.add_argument("--checkpoint", "-c", type=str, required=True,
                        help="path to checkpoint")
    parser.add_argument("--device", "-d", type=str, default="cuda")
    parser.add_argument("--data", type=str, default="examples/",
                        help="Directory containing the .png images to test on")
    parser.add_argument("--cmap", type=str, default="gray",
                        help="Colormap for the predictions")
    parser.add_argument("--save_dir", type=str, default="examples/processed/",
                        help="Directory for the colorized predictions "
                             "(deleted and recreated on every run)")
    args = parser.parse_args()
    # save_dir is wiped on every run; refuse to wipe the input images.
    if os.path.abspath(args.save_dir) == os.path.abspath(args.data):
        parser.error("--save_dir must differ from --data")
    return args


def _select_device(requested):
    """Return a torch device for *requested* ("cuda", "cuda:N", anything else -> CPU).

    Falls back to CPU (with a message) when CUDA is requested but unavailable.
    """
    if requested == "cuda" or requested.startswith("cuda:"):
        if torch.cuda.is_available():
            return torch.device(requested)
        print("CUDA requested but not available; falling back to CPU")
    return torch.device("cpu")


def _reset_dir(path):
    """Delete directory *path* if it exists, then recreate it empty (with parents)."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _result_path(save_dir, img_name):
    """Return the output path for input image *img_name* inside *save_dir*."""
    # splitext keeps inner dots ("a.b.png" -> "a.b"), unlike split(".")[0].
    stem = os.path.splitext(os.path.basename(img_name))[0]
    return os.path.join(save_dir, stem + "_result.png")


def main():
    """Run a trained DenseDepth checkpoint over a folder of PNGs and save depth maps.

    Every ``*.png`` directly inside ``--data`` is loaded with ``load_images``,
    passed through the model, normalised with ``DepthNorm``, colorized with
    ``--cmap`` and written to ``--save_dir`` as ``<stem>_result.png``.
    ``--save_dir`` is deleted and recreated before processing starts.

    Raises:
        FileNotFoundError: if ``--checkpoint`` is not an existing file.
    """
    args = _parse_args()
    if not os.path.isfile(args.checkpoint):
        raise FileNotFoundError("{} no such file".format(args.checkpoint))

    device = _select_device(args.device)
    print("Using device: {}".format(device))

    # Initializing the model and loading the pretrained model
    model = DenseDepth(encoder_pretrained=False)
    ckpt = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    print("model load from checkpoint complete ...")

    # Get Test Images; join handles --data with or without a trailing slash,
    # and sorting makes the processing order deterministic.
    img_list = sorted(glob(os.path.join(args.data, "*.png")))

    # Start each run with an empty output directory.
    save_path = args.save_dir
    _reset_dir(save_path)

    # Set model to eval mode
    model.eval()
    print(f"Number of images detected: {len(img_list)}")

    # Begin testing loop
    print("Begin Test Loop ...")
    for img_name in img_list:
        img = load_images([img_name])
        img = torch.Tensor(img).float().to(device)
        print("Processing {}, Tensor Shape: {}".format(img_name, img.shape))
        with torch.no_grad():
            preds = DepthNorm(model(img).squeeze(0))
        output = colorize(preds.data, cmap=args.cmap)
        # presumably colorize returns channel-first data; cv2 expects HWC.
        output = output.transpose((1, 2, 0))
        # NOTE(review): cv2.imwrite treats channels as BGR; if colorize returns
        # RGB, non-gray colormaps are saved with red/blue swapped — confirm.
        cv2.imwrite(_result_path(save_path, img_name), output)
        print("Processing {} done.".format(img_name))
if __name__ == "__main__":
    # Log the installed torch build before any work so run output records it.
    torch_version = torch.__version__
    print("Using torch version: ", torch_version)
    main()