sam_test.py
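"""Test a fine-tuned SAM (vit_b) on a shadow-video test split.

For each frame, read a bounding-box prompt from ./dataset/sam_test.json,
run SAM's image encoder, prompt encoder, and mask decoder with that box,
and write the thresholded mask to results/sam/labels/<video>/<frame>.png.
"""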
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'  # pin to one GPU; must be set before torch is imported
import torch
import numpy as np
import cv2
import json
from tqdm import tqdm
from sam import sam_model_registry
from sam.utils.transforms import ResizeLongestSide
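# NOTE: `sam` appears to be this repo's local copy of
# facebookresearch/segment-anything; `sam_model_registry` and
# `ResizeLongestSide` match that package's API.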
def main():
    sam_checkpoint = "./checkpoints/chk_sam/finetune.pth"
    model_type = "vit_b"
    device = "cuda"
    path = "/data/wangyh/data4/Datasets/shadow/video_new/visha4/test"

    # Load the fine-tuned SAM checkpoint and switch to inference mode.
    sam_model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam_model.to(device=device)
    sam_model.eval()

    # Per-frame box prompts, keyed by video name and frame file name.
    with open("./dataset/sam_test.json", "r") as f:
        meta = json.load(f)

    # Collect every frame path, video by video, in sorted order.
    img_all = []
    videolists = sorted(os.listdir(os.path.join(path, "images")))
    for video in videolists:
        v_path = os.path.join(path, "images", video)
        imglist = sorted(os.listdir(v_path))  # frames of the current video
        img_all += [os.path.join(v_path, img_file) for img_file in imglist]
    lab_all = [p.replace("images", "labels").replace(".jpg", ".png") for p in img_all]  # only used by the commented-out label loading below
    # ResizeLongestSide scales the longest image side to the encoder's
    # input size (1024 for vit_b) while keeping the aspect ratio.
    sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)

    for i, img_path in enumerate(tqdm(img_all)):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # label = cv2.imread(lab_all[i])
        # label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)

        # Resize proportionally, move to CHW, then let SAM normalize and pad to a square.
        resize_image = sam_trans.apply_image(image)
        image_tensor = torch.as_tensor(resize_image, device=device)
        input_image_torch = image_tensor.permute(2, 0, 1).contiguous()[None, :, :, :]  # HWC -> 1xCxHxW
        input_image = sam_model.preprocess(input_image_torch)
        original_image_size = image.shape[:2]
        input_size = tuple(input_image_torch.shape[-2:])

        # Look up this frame's box prompt(s) in the metadata.
        video_name = img_path.split('/')[-2]
        file_name = img_path.split('/')[-1].replace(".jpg", ".png")
        bboxes = np.array(meta[video_name][file_name]['bbox'])
        with torch.no_grad():
            # Map the boxes into the resized image's coordinate frame.
            box = sam_trans.apply_boxes(bboxes, original_image_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            if len(box_torch.shape) == 2:
                box_torch = box_torch[:, None, :]  # (B, 1, 4)

            image_embedding = sam_model.image_encoder(input_image)
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )
            low_res_masks, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

        # Merge the per-box masks into one map, upsample to the original
        # resolution, threshold, and save as an 8-bit PNG.
        low_res_masks = torch.sum(low_res_masks, dim=0, keepdim=True)
        upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
        mask_save = (upscaled_masks > 0.5)[0].detach().squeeze(0).cpu().numpy()
        mask_save = (mask_save * 255).astype(np.uint8)

        vi = img_path.split('/')[-2]
        fi = img_path.split('/')[-1].split('.')[-2] + '.png'
        os.makedirs(os.path.join("results", "sam", "labels", vi), mode=0o777, exist_ok=True)
        cv2.imwrite(os.path.join("results", "sam", "labels", vi, fi), mask_save)
if __name__ == '__main__':
    main()
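# Usage sketch (inferred from the constants above, not documented here):
#   weights:   ./checkpoints/chk_sam/finetune.pth
#   metadata:  ./dataset/sam_test.json, shaped like
#              {"<video>": {"<frame>.png": {"bbox": [[x0, y0, x1, y1], ...]}}}
#              (XYXY boxes, judging by apply_boxes and the (B, 1, 4) reshape)
#   run:       python sam_test.py
# Predicted masks land in results/sam/labels/<video>/<frame>.png.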