#!/usr/bin/env python3
import os
import time
import PIL
import clip
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from .utils import AttrDict, load_image, download_model, print_table


class CLIPEmbedding():
    """
    CLIP feature extractor and projector for generating image embeddings.
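
    A minimal usage sketch (the image path and prompt below are just placeholders):

        model = CLIPEmbedding('ViT-L/14@336px', dtype=np.float16)
        image_embedding = model.embed_image('test.jpg')         # torch.Tensor by default
        text_embedding = model.embed_text('a photo of a cat')   # pass return_tensors='np' for numpy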
"""
def __init__(self, model='ViT-L/14@336px', dtype=np.float32, crop=True, model_cache='/data/models/clip', **kwargs):
"""
Parameters:
model (str) -- name or path to CLIP model, one of:
'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
"""
self.config = AttrDict(name=model)
self.image_stats = AttrDict()
self.text_stats = AttrDict()
self.extensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.stream = None
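
        # normalize the requested dtype and map it to the equivalent torch dtype
        # (only FP32 and FP16 are supported by this wrapper)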
        dtype = np.dtype(dtype)

        if dtype == np.float32:
            self.config.dtype = torch.float32
        elif dtype == np.float16:
            self.config.dtype = torch.float16
        else:
            raise ValueError(f"unsupported datatype: {dtype}")

        print(f'-- loading CLIP {model}')

        self.model, _ = clip.load(
            model,
            device=self.device,
            jit=False,
            download_root=model_cache
        )

        self.config.crop = crop
        self.config.input_shape = (self.model.visual.input_resolution, self.model.visual.input_resolution)

        self.image_model = self.model.visual

        """
        # TensorRT disabled right now for 8.4 - needs PyTorch 2.1, onnxruntime, not much faster, wrong FP16 results
        trt_path = os.path.join(model_cache, model.replace('/','-').replace('@','-') + '-trt.pth')

        if os.path.isfile(trt_path):
            print(f"-- loading TensorRT model from {trt_path}")
            self.image_model = torch2trt.TRTModule()
            self.image_model.load_state_dict(torch.load(trt_path))
        else:
            # needs PyTorch 2.1 and onnxruntime
            self.image_model = torch2trt.torch2trt(
                self.model.visual.cpu().float(),  # put on CPU for onnx export
                [torch.ones(1, 3, *self.config.input_shape, dtype=torch.float32)],  # TRT expects FP32 input
                fp16_mode=(self.config.dtype == torch.float16),
                log_level=tensorrt.Logger.VERBOSE,
                use_onnx=True,
                onnx_opset=14,
            )
            print(f"-- saving TensorRT model to {trt_path}")
            torch.save(self.image_model.state_dict(), trt_path)
        """

        # Pre-processing is able to use GPU with torchvision (cropping is optional)
        # https://github.com/openai/CLIP/blob/a1d071733d7111c9c014f024669f959182114e33/clip/clip.py#L79
        self.preprocessor = torch.nn.Sequential()

        self.preprocessor.append(
            transforms.Resize(
                self.config.input_shape[0] if crop else self.config.input_shape,
                interpolation=transforms.InterpolationMode.BICUBIC  # BILINEAR
            )
        )

        if crop:
            self.preprocessor.append(transforms.CenterCrop(self.config.input_shape[0]))
            print("-- image cropping enabled")

        self.preprocessor.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
        self.preprocessor.append(transforms.ConvertImageDtype(self.config.dtype))

        self.preprocessor = self.preprocessor.eval().to(self.device)

        print(self.model)
        print(f"-- {self.config.name} warmup")
        for i in range(2):
            self.embed_image(PIL.Image.new('RGB', self.config.input_shape, (255,255,255)))

        print_table(self.config)

    def embed_image(self, image, return_tensors='pt', **kwargs):
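        """
        Compute the CLIP embedding for an image.

        Parameters:
          image (str|PIL.Image|np.ndarray|torch.Tensor) -- path to an image file, or an
                                                           already-loaded image to encode
          return_tensors (str) -- 'pt' to return a torch.Tensor (default), or 'np' for np.ndarray

        Returns the output of the vision encoder, scaled by the model's logit scale.
        """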
        if isinstance(image, str):
            image = load_image(image)  #api='torchvision')  # torchvision not any faster, and needs built with PNG

        time_begin_pre = time.perf_counter()

        with torch.cuda.StreamContext(self.stream), torch.inference_mode():
            if isinstance(image, PIL.Image.Image) or isinstance(image, np.ndarray):
                image = transforms.functional.to_tensor(image)
            #else:
            #    image = image.to(device=self.device, dtype=self.config.dtype) / 255.0  # needed when load_image(api='torchvision')

            image = image.to(device=self.device, dtype=self.config.dtype)
            image = self.preprocessor(image).unsqueeze(0)

            time_begin_enc = time.perf_counter()
            output = self.image_model(image)  #self.model.encode_image(image)
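
            # scale the features by CLIP's learned logit scale (temperature), so that
            # downstream dot products with text embeddings behave like CLIP similarity logits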
            output = self.model.logit_scale.exp() * output
            time_end_enc = time.perf_counter()

        self.config.output_shape = output.shape

        self.image_stats.clip_time = time_end_enc - time_begin_pre
        self.image_stats.clip_rate = 1.0 / self.image_stats.clip_time
        self.image_stats.preprocess_time = time_begin_enc - time_begin_pre
        self.image_stats.encode_time = time_end_enc - time_begin_enc
        self.image_stats.input_shape = f"({image.shape[-1]},{image.shape[-2]}) -> {self.config.input_shape}"
        self.image_stats.output_shape = self.config.output_shape

        if return_tensors == 'np':
            return output.detach().cpu().numpy()
        elif return_tensors == 'pt':
            return output
        else:
            raise ValueError(f"return_tensors should be 'np' or 'pt' (was '{return_tensors}')")

    def embed_text(self, text, return_tensors='pt', **kwargs):
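        """
        Compute the CLIP embedding for text.

        Parameters:
          text (str|list[str]|torch.Tensor) -- a string, list of strings, or pre-tokenized token IDs
          return_tensors (str) -- 'pt' to return a torch.Tensor (default), or 'np' for np.ndarray
        """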
        if isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], str)):
            time_begin = time.perf_counter()
            text = clip.tokenize(text).to(self.device)
            self.text_stats.tokens = text.shape
            self.text_stats.tokens_time = time.perf_counter() - time_begin

        time_begin = time.perf_counter()

        with torch.cuda.StreamContext(self.stream), torch.inference_mode():
            output = self.model.encode_text(text)

        self.text_stats.encode_time = time.perf_counter() - time_begin
        self.text_stats.output_shape = output.shape

        if return_tensors == 'np':
            return output.detach().cpu().numpy()
        elif return_tensors == 'pt':
            return output
        else:
            raise ValueError(f"return_tensors should be 'np' or 'pt' (was '{return_tensors}')")