-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy path__init__.py
96 lines (82 loc) · 4.44 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import clip
from PIL import Image
from torchvision import transforms
from .constants import CACHE_DIR
# Normalization statistics shared by the BLIP and X-VLM preprocessing pipelines.
_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
_IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

# open_clip pretrained tag used for "laion-clip:<arch>" when no tag is given.
_LAION_DEFAULT_PRETRAINED = "laion2b_s34b_b79k"


def _resize_normalize_preprocess(image_size=384):
    """Return a pipeline: bicubic resize to a square, convert to tensor, normalize."""
    return transforms.Compose([
        transforms.Resize(
            (image_size, image_size),
            # Use torchvision's InterpolationMode enum (not PIL's Image.BICUBIC int),
            # which is the documented type for Resize's `interpolation` argument.
            interpolation=transforms.functional.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGE_MEAN, _IMAGE_STD),
    ])


def _wrap_clip(model, device):
    """Put a CLIP-style model in eval mode and wrap it in CLIPWrapper."""
    from .clip_models import CLIPWrapper
    return CLIPWrapper(model.eval(), device)


def _name_segments(model_name):
    """Return the ':'-separated parts that follow the family prefix of *model_name*.

    For example, "openai-clip:ViT-B/32" -> ["ViT-B/32"].

    Raises:
        ValueError: if no non-empty part follows the first ':'.
    """
    segments = model_name.split(":")[1:]
    if not segments or not segments[0]:
        raise ValueError(f"Model name {model_name!r} is missing its ':<variant>' suffix")
    return segments


def get_model(model_name, device, root_dir=CACHE_DIR):
    """
    Helper function that returns a model and a potential image preprocessing function.

    Args:
        model_name: Model identifier. Recognized values:
            - containing "openai-clip", as "openai-clip:<variant>": OpenAI CLIP
              loaded with ``clip.load(<variant>)``.
            - "blip-flickr-base", "blip-coco-base": BLIPModelWrapper variants.
            - "xvlm-flickr", "xvlm-coco": XVLMWrapper variants.
            - "flava": FlavaWrapper.
            - "NegCLIP": open_clip ViT-B-32 with weights from
              ``<root_dir>/negclip.pth``, downloaded with gdown if absent.
            - "coca": open_clip "coca_ViT-B-32" with "laion2B-s13B-b90k" weights.
            - containing "laion-clip", as "laion-clip:<arch>[:<pretrained_tag>]":
              open_clip model ``<arch>``; the tag defaults to "laion2b_s34b_b79k".
        device: Device passed through to the loaders and wrappers.
        root_dir: Directory passed to the loaders/wrappers as their download
            and cache location.

    Returns:
        A ``(model, image_preprocess)`` tuple. ``image_preprocess`` is None for
        "flava" (presumably FlavaWrapper preprocesses internally — confirm).

    Raises:
        ValueError: if ``model_name`` is not recognized, or an "openai-clip" /
            "laion-clip" name lacks its ':<variant>' part.
        RuntimeError: if the NegCLIP checkpoint is still missing after download.
    """
    if "openai-clip" in model_name:
        variant = _name_segments(model_name)[0]
        model, image_preprocess = clip.load(variant, device=device, download_root=root_dir)
        return _wrap_clip(model, device), image_preprocess

    if model_name in ("blip-flickr-base", "blip-coco-base"):
        from .blip_models import BLIPModelWrapper
        # The wrapper variant string is the model name itself.
        blip_model = BLIPModelWrapper(root_dir=root_dir, device=device, variant=model_name)
        return blip_model, _resize_normalize_preprocess(384)

    if model_name in ("xvlm-flickr", "xvlm-coco"):
        from .xvlm_models import XVLMWrapper
        xvlm_model = XVLMWrapper(root_dir=root_dir, device=device, variant=model_name)
        return xvlm_model, _resize_normalize_preprocess(384)

    if model_name == "flava":
        from .flava import FlavaWrapper
        return FlavaWrapper(root_dir=root_dir, device=device), None

    if model_name == "NegCLIP":
        import open_clip
        path = os.path.join(root_dir, "negclip.pth")
        if not os.path.exists(path):
            print("Downloading the NegCLIP model...")
            import gdown
            # The cache directory may not exist yet on a fresh machine.
            os.makedirs(root_dir, exist_ok=True)
            gdown.download(id="1ooVVPxB-tvptgmHlIMMFGV3Cg-IrhbRZ", output=path, quiet=False)
            # Fail here with a clear message instead of deep inside open_clip.
            if not os.path.exists(path):
                raise RuntimeError(f"Failed to download the NegCLIP checkpoint to {path}")
        model, _, image_preprocess = open_clip.create_model_and_transforms(
            'ViT-B-32', pretrained=path, device=device)
        return _wrap_clip(model, device), image_preprocess

    if model_name == "coca":
        import open_clip
        model, _, image_preprocess = open_clip.create_model_and_transforms(
            model_name="coca_ViT-B-32", pretrained="laion2B-s13B-b90k", device=device)
        return _wrap_clip(model, device), image_preprocess

    if "laion-clip" in model_name:
        import open_clip
        segments = _name_segments(model_name)
        variant = segments[0]
        # Pretrained tags are architecture-specific in open_clip, so allow an
        # optional third segment; two-segment names keep the original default.
        if len(segments) > 1 and segments[1]:
            pretrained = segments[1]
        else:
            pretrained = _LAION_DEFAULT_PRETRAINED
        model, _, image_preprocess = open_clip.create_model_and_transforms(
            model_name=variant, pretrained=pretrained, device=device)
        return _wrap_clip(model, device), image_preprocess

    raise ValueError(f"Unknown model {model_name}")