-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_text_feats_cache.py
77 lines (62 loc) · 2.26 KB
/
gen_text_feats_cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from net.vqa_bg import BG_VQA
import clip
from net.vqa_fg import FG_VQA
from visual_questions import QUESTIONS
from misc.cat_names import category_dict
from argparse import ArgumentParser
def parse_args():
    """Parse the command line for the text-feature cache generator.

    Returns:
        argparse.Namespace with:
            dataset: 'voc' or 'coco'; selects the category names used for
                the foreground cache.
            vqa_fg_file / vqa_bg_file: input VQA file paths, or None when
                omitted. A cache is only built for a file that is given.
            vqa_fg_cache_file / vqa_bg_cache_file: output cache paths,
                or None when omitted.
            clip: CLIP backbone name (default 'ViT-B/32').
    """
    parser = ArgumentParser(
        description='Generate CLIP text-feature caches for foreground and/or '
                    'background VQA files.'
    )
    parser.add_argument(
        'dataset', type=str, choices=['voc', 'coco'],
        help='dataset whose category names are used for the foreground cache',
    )
    parser.add_argument(
        '--vqa_fg_file', type=str,
        help='foreground VQA file; the foreground cache is built only if given',
    )
    parser.add_argument(
        '--vqa_bg_file', type=str,
        help='background VQA file; the background cache is built only if given',
    )
    parser.add_argument(
        '--vqa_fg_cache_file', type=str,
        help='cache path for the foreground text features',
    )
    parser.add_argument(
        '--vqa_bg_cache_file', type=str,
        help='cache path for the background text features',
    )
    parser.add_argument(
        '--clip', type=str, default='ViT-B/32', choices=['ViT-B/32', 'ViT-L/14'],
        help='CLIP backbone used to encode the text (default: %(default)s)',
    )
    return parser.parse_args()
def gen_bg_text_feats_cache(device, clip_model, clip_name, vqa_file_path, cache_file_path):
    """Build the background VQA helper and call its ``gen_cache()``.

    Args:
        device: device forwarded to ``BG_VQA``.
        clip_model: already-loaded CLIP model handed to ``BG_VQA``.
        clip_name: CLIP backbone name (e.g. 'ViT-B/32').
        vqa_file_path: path of the background VQA file.
        cache_file_path: forwarded to ``BG_VQA`` as ``cache_path``.
    """
    bg_questions = QUESTIONS['bg']
    helper = BG_VQA(
        vqa_file_path,
        bg_questions,
        clip_model,
        clip_name,
        device=device,
        cache_path=cache_file_path,
    )
    helper.gen_cache()
def gen_fg_text_feats_cache(device, clip_model, clip_name, vqa_file_path, cache_file_path, dataset_name):
    """Construct the foreground VQA helper with ``modify_cache=True``.

    Args:
        device: accepted for symmetry with ``gen_bg_text_feats_cache``;
            it is not forwarded anywhere in this function.
        clip_model: already-loaded CLIP model handed to ``FG_VQA``.
        clip_name: CLIP backbone name (e.g. 'ViT-B/32').
        vqa_file_path: path of the foreground VQA file.
        cache_file_path: forwarded to ``FG_VQA`` as ``cache_path``.
        dataset_name: key into ``category_dict`` ('voc' or 'coco' from the CLI).
    """
    # NOTE(review): unlike the background path, no gen_cache() call is made
    # and the instance is discarded, so cache generation presumably happens
    # inside FG_VQA's constructor when modify_cache=True -- confirm.
    # `device` is also not passed to FG_VQA; verify that is intended.
    fg_questions = QUESTIONS['fg']
    class_names = category_dict[dataset_name]
    FG_VQA(
        vqa_file_path,
        fg_questions,
        class_names,
        clip_model,
        clip_name,
        modify_cache=True,
        cache_path=cache_file_path,
    )
if __name__ == '__main__':
    # Script entry point: parse the CLI, load CLIP once, then build whichever
    # text-feature caches were requested (background and/or foreground).
    args = parse_args()

    # Fail fast, before the slow CLIP load, when no VQA file was given.
    # Previously the model was loaded and the script then exited doing nothing.
    if args.vqa_bg_file is None and args.vqa_fg_file is None:
        raise SystemExit('Nothing to do: pass --vqa_bg_file and/or --vqa_fg_file.')

    device = 'cuda:0'
    clip_name = args.clip
    # Load on CPU, then move the model to `device`. The second value returned
    # by clip.load (the image preprocessing transform) is not used here.
    # NOTE(review): the OpenAI clip package casts the model to float32 when
    # loaded with device='cpu', so this presumably yields fp32 weights on the
    # GPU rather than fp16 -- confirm that is intended.
    clip_model, _preprocess = clip.load(clip_name, device='cpu')
    clip_model.to(device)
    clip_model.eval()

    if args.vqa_bg_file is not None:
        print("\n\t\t=== Generating background text features cache ===\n")
        gen_bg_text_feats_cache(
            device=device,
            clip_model=clip_model,
            clip_name=clip_name,
            vqa_file_path=args.vqa_bg_file,
            cache_file_path=args.vqa_bg_cache_file,
        )

    if args.vqa_fg_file is not None:
        print("\n\t\t=== Generating foreground text features cache ===\n")
        gen_fg_text_feats_cache(
            device=device,
            clip_model=clip_model,
            clip_name=clip_name,
            vqa_file_path=args.vqa_fg_file,
            cache_file_path=args.vqa_fg_cache_file,
            dataset_name=args.dataset,
        )