13
13
14
14
15
15
class VisualEncoder(nn.Module):
    """Image tower: a headless timm backbone, a pooling stage, and a linear projection.

    Args:
        backbone: timm model name to instantiate (weights loaded elsewhere).
        dim: feature dimension of the backbone's token outputs.
        output_dim: dimension of the projected embedding.
        backbone_type: tag used elsewhere to pick the feature-extraction path.
        pooling: 'cls', 'mean', or 'attention'.
    """

    def __init__(
        self,
        backbone,
        dim,
        output_dim,
        backbone_type,
        pooling='cls',
    ):
        super().__init__()

        # num_classes=0 strips the classifier head; pretrained=False because
        # weights are restored from a checkpoint, not from timm.
        self.encoder = timm.create_model(
            backbone,
            pretrained=False,
            num_classes=0,
        )
        self.backbone_type = backbone_type
        self.pooling = pooling

        if self.pooling == 'attention':
            # Single-head attention pooling: learned queries attend over tokens.
            self.attention_pooling = nn.MultiheadAttention(
                dim,
                1,
                batch_first=True,
                dropout=0.1,
            )
            # 197 queries — presumably 196 patch tokens + CLS for a 224px/16px
            # ViT; NOTE(review): assumed, confirm against the backbone used.
            self.queries = nn.Parameter(torch.randn(1, 197, dim))

        self.proj = nn.Linear(dim, output_dim, bias=False)
@@ -53,7 +62,11 @@ def forward_features(self, x):
53
62
features = self .forward_features_conv (x )
54
63
55
64
if self .pooling == 'attention' :
56
- return self .attention_pooling (self .queries .expand (x .shape [0 ], - 1 , - 1 ), features , features )[0 ]
65
+ return self .attention_pooling (
66
+ self .queries .expand (x .shape [0 ], - 1 , - 1 ),
67
+ features ,
68
+ features ,
69
+ )[0 ]
57
70
58
71
return features
59
72
@@ -64,14 +77,14 @@ def forward_features_vit(self, x):
64
77
x = self .encoder .patch_embed (x )
65
78
x = self .encoder ._pos_embed (x )
66
79
x = self .encoder .norm_pre (x )
67
-
80
+
68
81
for block in self .encoder .blocks :
69
82
x = block (x )
70
-
83
+
71
84
x = self .encoder .norm (x )
72
-
73
- return x
74
-
85
+
86
+ return x
87
+
75
88
def get_embedding (self , x , project = True ):
76
89
if isinstance (x , list ):
77
90
x = x [- 1 ]
@@ -80,29 +93,31 @@ def get_embedding(self, x, project=True):
80
93
x = x [:, 0 ]
81
94
elif self .pooling == 'mean' :
82
95
x = x .mean (dim = 1 )
83
-
96
+
84
97
if project :
85
98
return self .proj (x )
86
99
87
100
return x
88
101
89
102
90
103
class TextEncoder (nn .Module ):
91
- def __init__ (self ,
92
- backbone ,
93
- backbone_type ,
94
- unimodal_n_layers ,
95
- context_dim ,
96
- dim ,
97
- output_dim ,
98
- pooling = 'cls' ,
99
- head_one_neuron = False ):
104
+ def __init__ (
105
+ self ,
106
+ backbone ,
107
+ backbone_type ,
108
+ unimodal_n_layers ,
109
+ context_dim ,
110
+ dim ,
111
+ output_dim ,
112
+ pooling = 'cls' ,
113
+ head_one_neuron = False ,
114
+ ):
100
115
101
116
super ().__init__ ()
102
117
self .backbone = TextEncoderBackbone (
103
118
backbone ,
104
119
backbone_type ,
105
- unimodal_n_layers
120
+ unimodal_n_layers ,
106
121
)
107
122
108
123
if context_dim != dim :
@@ -118,26 +133,31 @@ def __init__(self,
118
133
def forward (self , x , attention_mask , causal = False ):
119
134
features = self .forward_unimodal (x , attention_mask , causal )
120
135
return features , self .get_embedding (features , attention_mask )
121
-
136
+
122
137
def forward_unimodal (self , x , attention_mask , causal = False ):
123
- prep_attention_mask = self .prepare_attention_mask (attention_mask , causal )
138
+ prep_attention_mask = self .prepare_attention_mask (
139
+ attention_mask ,
140
+ causal ,
141
+ )
124
142
x = self .backbone .embeddings (x )
125
143
126
144
for layer in self .backbone .unimodal_encoder :
127
145
x = layer (x , prep_attention_mask )[0 ]
128
-
146
+
129
147
return x
130
148
131
149
def forward_multimodal (
132
150
self ,
133
151
x ,
134
152
attention_mask ,
135
153
context ,
136
- causal = False ):
137
-
138
- prep_attention_mask = self .prepare_attention_mask (attention_mask , causal )
154
+ causal = False ,
155
+ ):
156
+ prep_attention_mask = self .prepare_attention_mask (
157
+ attention_mask ,
158
+ causal ,
159
+ )
139
160
context = self .context_proj (context )
140
-
141
161
for layer in self .backbone .multimodal_encoder :
142
162
x , _ , _ = layer (x , prep_attention_mask , context )
143
163
@@ -147,15 +167,14 @@ def get_matching_scores(
147
167
self ,
148
168
x ,
149
169
attention_mask ,
150
- context ):
151
-
170
+ context ,
171
+ ):
152
172
embeddings = self .forward_multimodal (
153
173
x ,
154
174
attention_mask ,
155
175
context ,
156
- False
176
+ False ,
157
177
)
158
-
159
178
return self ._logit_and_norm (embeddings )
160
179
161
180
def _logit_and_norm (self , embeddings ):
@@ -170,18 +189,21 @@ def get_embedding(self, x, attention_mask, project=True):
170
189
mask_expanded = attention_mask .unsqueeze (2 )
171
190
vec_sum = (x * mask_expanded ).sum (dim = 1 )
172
191
x = vec_sum / mask_expanded .sum (dim = 1 )
192
+
173
193
elif self .pooling == 'cls' :
174
194
x = x [:, 0 ]
175
195
176
196
if project :
177
197
return self .proj (x )
178
198
179
199
return x
180
-
200
+
181
201
def prepare_attention_mask (self , mask , causal = False ):
182
202
if causal :
183
- causal_mask = torch .ones (mask .size (1 ), mask .size (1 ), device = mask .device ).tril ()
184
- mask = mask [:, None , :] * causal_mask [None , :, :] # bs x seq_len x seq_len
203
+ causal_mask = torch .ones (
204
+ mask .size (1 ), mask .size (1 ), device = mask .device ).tril ()
205
+ # bs x seq_len x seq_len
206
+ mask = mask [:, None , :] * causal_mask [None , :, :]
185
207
mask = (1 - mask ) * - 10e9
186
208
return mask [:, None ]
187
209
@@ -195,6 +217,7 @@ class TextEncoderBackbone(nn.Module):
195
217
'roberta' : (RobertaConfig , RobertaModel , RobertaAttention ),
196
218
'xlm_roberta' : (XLMRobertaConfig , XLMRobertaModel , XLMRobertaAttention )
197
219
}
220
+
198
221
def __init__ (
199
222
self ,
200
223
pretrained ,
@@ -204,19 +227,23 @@ def __init__(
204
227
super ().__init__ ()
205
228
self .unimodal_n_layers = unimodal_n_layers
206
229
207
- config_file = hf_hub_download (repo_id = pretrained , filename = 'config.json' )
230
+ config_file = hf_hub_download (
231
+ repo_id = pretrained ,
232
+ filename = 'config.json' ,
233
+ )
208
234
config_cls , model_cls , attention_layer_cls = self .type2classes [backbone_type ]
209
235
config = config_cls .from_json_file (config_file )
210
236
model = model_cls (config )
211
-
237
+
212
238
self .construct_model (model , attention_layer_cls , config )
213
239
214
240
def construct_model (
215
241
self ,
216
242
backbone ,
217
243
attention_layer_cls ,
218
- config ):
219
-
244
+ config ,
245
+ ):
246
+
220
247
self .unimodal_encoder = backbone .encoder .layer [:self .unimodal_n_layers ]
221
248
self .embeddings = backbone .embeddings
222
249
self .multimodal_encoder = []
@@ -226,27 +253,33 @@ def construct_model(
226
253
FusedTransformerLayer (
227
254
config ,
228
255
attention_layer_cls ,
229
- layer )
256
+ layer ,
230
257
)
258
+ )
231
259
232
260
self .multimodal_encoder = nn .ModuleList (self .multimodal_encoder )
233
261
262
+
234
263
class FusedTransformerLayer(nn.Module):
    """Transformer layer with self-attention plus cross-attention to visual context.

    Reuses the self-attention and feed-forward sublayers of an existing HF
    layer (`base_layer`) and adds a freshly initialised cross-attention module.
    """

    def __init__(self, config, attention_layer_cls, base_layer):
        super().__init__()

        # Pretrained sublayers borrowed from the donor layer.
        self.self_attention = base_layer.attention
        self.intermediate = base_layer.intermediate
        self.output = base_layer.output
        # New cross-attention over the (projected) image features.
        self.cross_attention = attention_layer_cls(config)

    def forward(self, x, attention_mask, context):
        """Return (hidden states, self-attention probs, cross-attention probs)."""
        hidden, self_attention_probs = self.self_attention(
            x,
            attention_mask,
            output_attentions=True,
        )
        hidden, cross_attention_probs = self.cross_attention(
            hidden,
            encoder_hidden_states=context,
            output_attentions=True,
        )
        ffn_hidden = self.intermediate(hidden)
        layer_output = self.output(ffn_hidden, hidden)
        return layer_output, self_attention_probs, cross_attention_probs
@@ -257,19 +290,8 @@ def __init__(self, config):
257
290
super ().__init__ ()
258
291
self .img_encoder = VisualEncoder (** config ['img_encoder' ])
259
292
self .text_encoder = TextEncoder (** config ['text_encoder' ])
260
-
261
- self .preprocess_image = Compose ([
262
- Resize (224 , interpolation = InterpolationMode .BICUBIC ),
263
- lambda x : x .convert ('RGB' ),
264
- CenterCrop (224 ),
265
- ToTensor (),
266
- Normalize (
267
- mean = (0.48145466 , 0.4578275 , 0.40821073 ),
268
- std = (0.26862954 , 0.26130258 , 0.27577711 )
269
- )
270
- ])
271
-
272
- self ._tokenizer = AutoTokenizer .from_pretrained (config ['text_encoder' ]['backbone' ])
293
+ self ._tokenizer = AutoTokenizer .from_pretrained (
294
+ config ['text_encoder' ]['backbone' ])
273
295
274
296
def encode_image (self , x : torch .Tensor , return_features = False ):
275
297
features , embs = self .img_encoder (x )
@@ -278,7 +300,7 @@ def encode_image(self, x: torch.Tensor, return_features=False):
278
300
return features , embs
279
301
280
302
return embs
281
-
303
+
282
304
def encode_text (self , x : dict , return_features = False ):
283
305
features , embs = self .text_encoder (x ['input_ids' ], x ['attention_mask' ])
284
306
@@ -293,7 +315,8 @@ def encode_multimodal(
293
315
text : dict = None ,
294
316
image_features : torch .Tensor = None ,
295
317
text_features : torch .Tensor = None ,
296
- attention_mask : torch .Tensor = None ):
318
+ attention_mask : torch .Tensor = None ,
319
+ ):
297
320
298
321
assert image is not None or image_features is not None , "Either 'image' or 'image_features' should be non None"
299
322
assert text is not None or text_features is not None , "Either 'text_data' or 'text_features' should be non None"
@@ -318,26 +341,50 @@ def encode_multimodal(
318
341
319
342
def get_matching_scores (
320
343
self ,
321
- x : torch .Tensor ):
344
+ x : torch .Tensor ,
345
+ ):
322
346
return self .text_encoder ._logit_and_norm (x )
323
347
324
-
325
348
def preprocess_text (self , x ):
326
- x = self ._tokenizer (x ,
327
- padding = 'max_length' ,
328
- truncation = True ,
329
- return_tensors = 'pt' ,
330
- pad_to_max_length = True ,
331
- max_length = 77 )
349
+ x = self ._tokenizer (
350
+ x ,
351
+ padding = 'max_length' ,
352
+ truncation = True ,
353
+ return_tensors = 'pt' ,
354
+ pad_to_max_length = True ,
355
+ max_length = 77 ,
356
+ )
332
357
if 'token_type_ids' in x :
333
358
del x ['token_type_ids' ]
334
-
359
+
335
360
return x
336
361
362
+ def preprocess_image (self , x ):
363
+ preprocessor = Compose ([
364
+ Resize (224 , interpolation = InterpolationMode .BICUBIC ),
365
+ lambda x : x .convert ('RGB' ),
366
+ CenterCrop (224 ),
367
+ ToTensor (),
368
+ Normalize (
369
+ mean = (0.48145466 , 0.4578275 , 0.40821073 ),
370
+ std = (0.26862954 , 0.26130258 , 0.27577711 )
371
+ )
372
+ ])
373
+ if isinstance (x , list ):
374
+ images = []
375
+ for image in x :
376
+ images .append (preprocessor (image ))
377
+
378
+ batch_images = torch .stack (images , dim = 0 )
379
+ return batch_images
380
+ else :
381
+ return preprocessor (x ).unsqueeze (0 )
382
+
383
+
337
384
def get_model (model_name , token = None ):
338
385
model_path = snapshot_download (
339
386
repo_id = model_name ,
340
- token = token
387
+ token = token ,
341
388
)
342
389
config_path = f'{ model_path } /config.json'
343
390
state = torch .load (f'{ model_path } /weight.pt' )
0 commit comments