
Commit 0f7d58e

Merge pull request #10 from unum-cloud/dev
Add: Support for batch image preprocessing
Refactor: UForm PEP8 formatting
2 parents 8a452ee + b7a49a8 · commit 0f7d58e

File tree

3 files changed (+125 −78 lines)


README.md

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ from PIL import Image
 text = 'a small red panda in a zoo'
 image = Image.open('red_panda.jpg')
 
-image_data = model.preprocess_image(image).unsqueeze(0)
+image_data = model.preprocess_image(image)
 text_data = model.preprocess_text(text)
 
 image_embedding = model.encode_image(image_data)
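Per the commit message, preprocess_image now also accepts a list of images and returns one batched tensor, so the README example extends naturally to batches. A minimal sketch of that usage (the second file name is hypothetical, and model comes from the README's earlier setup):

from PIL import Image

# 'black_bear.jpg' is a hypothetical file name, used only for illustration.
images = [Image.open('red_panda.jpg'), Image.open('black_bear.jpg')]

# A list of PIL images is preprocessed and stacked into a single tensor
# of shape (len(images), 3, 224, 224), ready for encode_image.
batch_data = model.preprocess_image(images)
batch_embeddings = model.encode_image(batch_data)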

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "uform"
-version = "0.0.5"
+version = "0.0.6"
 authors = [
     { name="Mikhail Kim", email="[email protected]" },
     { name="Vladimir Orshulevich", email="[email protected]" },

src/uform.py

Lines changed: 123 additions & 76 deletions
@@ -13,22 +13,31 @@
 
 
 class VisualEncoder(nn.Module):
-    def __init__(self,
-                 backbone,
-                 dim,
-                 output_dim,
-                 backbone_type,
-                 pooling='cls'):
+    def __init__(
+        self,
+        backbone,
+        dim,
+        output_dim,
+        backbone_type,
+        pooling='cls',
+    ):
 
         super().__init__()
-        self.encoder = timm.create_model(backbone,
-                                         pretrained=False,
-                                         num_classes=0)
+        self.encoder = timm.create_model(
+            backbone,
+            pretrained=False,
+            num_classes=0,
+        )
         self.backbone_type = backbone_type
         self.pooling = pooling
 
         if self.pooling == 'attention':
-            self.attention_pooling = nn.MultiheadAttention(dim, 1, batch_first=True, dropout=0.1)
+            self.attention_pooling = nn.MultiheadAttention(
+                dim,
+                1,
+                batch_first=True,
+                dropout=0.1,
+            )
             self.queries = nn.Parameter(torch.randn(1, 197, dim))
 
         self.proj = nn.Linear(dim, output_dim, bias=False)
@@ -53,7 +62,11 @@ def forward_features(self, x):
         features = self.forward_features_conv(x)
 
         if self.pooling == 'attention':
-            return self.attention_pooling(self.queries.expand(x.shape[0], -1, -1), features, features)[0]
+            return self.attention_pooling(
+                self.queries.expand(x.shape[0], -1, -1),
+                features,
+                features,
+            )[0]
 
         return features
 
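As context for the 'attention' pooling branch reformatted above: the learned queries parameter attends over the patch features, and index [0] of the nn.MultiheadAttention output keeps the attended values while dropping the attention weights. A standalone sketch with made-up sizes (dim=256, batch of 2), not taken from the repo's configs:

import torch
import torch.nn as nn

dim = 256                            # hypothetical embedding width
features = torch.randn(2, 197, dim)  # batch of 2 images, 197 patch tokens each

attention_pooling = nn.MultiheadAttention(dim, 1, batch_first=True, dropout=0.1)
queries = nn.Parameter(torch.randn(1, 197, dim))

# Learned queries are broadcast over the batch and attend to the features;
# the first element of the returned tuple is the pooled sequence.
pooled = attention_pooling(queries.expand(features.shape[0], -1, -1), features, features)[0]
print(pooled.shape)  # torch.Size([2, 197, 256])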

@@ -64,14 +77,14 @@ def forward_features_vit(self, x):
         x = self.encoder.patch_embed(x)
         x = self.encoder._pos_embed(x)
         x = self.encoder.norm_pre(x)
-
+
         for block in self.encoder.blocks:
             x = block(x)
-
+
         x = self.encoder.norm(x)
-
-        return x
-
+
+        return x
+
     def get_embedding(self, x, project=True):
         if isinstance(x, list):
             x = x[-1]
@@ -80,29 +93,31 @@ def get_embedding(self, x, project=True):
             x = x[:, 0]
         elif self.pooling == 'mean':
             x = x.mean(dim=1)
-
+
         if project:
             return self.proj(x)
 
         return x
 
 
 class TextEncoder(nn.Module):
-    def __init__(self,
-                 backbone,
-                 backbone_type,
-                 unimodal_n_layers,
-                 context_dim,
-                 dim,
-                 output_dim,
-                 pooling='cls',
-                 head_one_neuron=False):
+    def __init__(
+        self,
+        backbone,
+        backbone_type,
+        unimodal_n_layers,
+        context_dim,
+        dim,
+        output_dim,
+        pooling='cls',
+        head_one_neuron=False,
+    ):
 
         super().__init__()
         self.backbone = TextEncoderBackbone(
             backbone,
             backbone_type,
-            unimodal_n_layers
+            unimodal_n_layers,
         )
 
         if context_dim != dim:
@@ -118,26 +133,31 @@ def __init__(self,
     def forward(self, x, attention_mask, causal=False):
         features = self.forward_unimodal(x, attention_mask, causal)
         return features, self.get_embedding(features, attention_mask)
-
+
     def forward_unimodal(self, x, attention_mask, causal=False):
-        prep_attention_mask = self.prepare_attention_mask(attention_mask, causal)
+        prep_attention_mask = self.prepare_attention_mask(
+            attention_mask,
+            causal,
+        )
         x = self.backbone.embeddings(x)
 
         for layer in self.backbone.unimodal_encoder:
             x = layer(x, prep_attention_mask)[0]
-
+
         return x
 
     def forward_multimodal(
         self,
         x,
         attention_mask,
         context,
-        causal=False):
-
-        prep_attention_mask = self.prepare_attention_mask(attention_mask, causal)
+        causal=False,
+    ):
+        prep_attention_mask = self.prepare_attention_mask(
+            attention_mask,
+            causal,
+        )
         context = self.context_proj(context)
-
         for layer in self.backbone.multimodal_encoder:
             x, _, _ = layer(x, prep_attention_mask, context)
 
@@ -147,15 +167,14 @@ def get_matching_scores(
         self,
         x,
         attention_mask,
-        context):
-
+        context,
+    ):
         embeddings = self.forward_multimodal(
             x,
             attention_mask,
             context,
-            False
+            False,
         )
-
         return self._logit_and_norm(embeddings)
 
     def _logit_and_norm(self, embeddings):
@@ -170,18 +189,21 @@ def get_embedding(self, x, attention_mask, project=True):
             mask_expanded = attention_mask.unsqueeze(2)
             vec_sum = (x * mask_expanded).sum(dim=1)
             x = vec_sum / mask_expanded.sum(dim=1)
+
         elif self.pooling == 'cls':
             x = x[:, 0]
 
         if project:
             return self.proj(x)
 
         return x
-
+
     def prepare_attention_mask(self, mask, causal=False):
         if causal:
-            causal_mask = torch.ones(mask.size(1), mask.size(1), device=mask.device).tril()
-            mask = mask[:, None, :] * causal_mask[None, :, :]  # bs x seq_len x seq_len
+            causal_mask = torch.ones(
+                mask.size(1), mask.size(1), device=mask.device).tril()
+            # bs x seq_len x seq_len
+            mask = mask[:, None, :] * causal_mask[None, :, :]
         mask = (1 - mask) * -10e9
         return mask[:, None]
 
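A small worked example of the masking arithmetic in prepare_attention_mask above, using a toy padding mask of length 3 with the last position padded; the values are illustrative only:

import torch

mask = torch.tensor([[1., 1., 0.]])  # one sequence, last token is padding

# Lower-triangular causal mask combined with the padding mask,
# mirroring the causal=True branch above.
causal_mask = torch.ones(mask.size(1), mask.size(1), device=mask.device).tril()
mask = mask[:, None, :] * causal_mask[None, :, :]  # shape (1, 3, 3)

# Allowed positions become 0, masked positions a large negative bias
# that removes them from the softmax.
mask = (1 - mask) * -10e9
print(mask[:, None].shape)  # torch.Size([1, 1, 3, 3]), broadcastable over heads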

@@ -195,6 +217,7 @@ class TextEncoderBackbone(nn.Module):
         'roberta': (RobertaConfig, RobertaModel, RobertaAttention),
         'xlm_roberta': (XLMRobertaConfig, XLMRobertaModel, XLMRobertaAttention)
     }
+
     def __init__(
         self,
         pretrained,
@@ -204,19 +227,23 @@ def __init__(
         super().__init__()
         self.unimodal_n_layers = unimodal_n_layers
 
-        config_file = hf_hub_download(repo_id=pretrained, filename='config.json')
+        config_file = hf_hub_download(
+            repo_id=pretrained,
+            filename='config.json',
+        )
         config_cls, model_cls, attention_layer_cls = self.type2classes[backbone_type]
         config = config_cls.from_json_file(config_file)
         model = model_cls(config)
-
+
         self.construct_model(model, attention_layer_cls, config)
 
     def construct_model(
         self,
         backbone,
         attention_layer_cls,
-        config):
-
+        config,
+    ):
+
         self.unimodal_encoder = backbone.encoder.layer[:self.unimodal_n_layers]
         self.embeddings = backbone.embeddings
         self.multimodal_encoder = []
@@ -226,27 +253,33 @@ def construct_model(
                 FusedTransformerLayer(
                     config,
                     attention_layer_cls,
-                    layer)
+                    layer,
                 )
+            )
 
         self.multimodal_encoder = nn.ModuleList(self.multimodal_encoder)
 
+
 class FusedTransformerLayer(nn.Module):
     def __init__(self, config, attention_layer_cls, base_layer):
         super().__init__()
 
-        self.self_attention = base_layer.attention 
+        self.self_attention = base_layer.attention
         self.intermediate = base_layer.intermediate
         self.output = base_layer.output
         self.cross_attention = attention_layer_cls(config)
-
+
     def forward(self, x, attention_mask, context):
-        attention_output, self_attention_probs = self.self_attention(x, attention_mask, output_attentions=True)
+        attention_output, self_attention_probs = self.self_attention(
+            x,
+            attention_mask,
+            output_attentions=True,
+        )
        attention_output, cross_attention_probs = self.cross_attention(
             attention_output,
             encoder_hidden_states=context,
-            output_attentions=True
-        )  # [0]
+            output_attentions=True,
+        )  # [0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
         return layer_output, self_attention_probs, cross_attention_probs
@@ -257,19 +290,8 @@ def __init__(self, config):
         super().__init__()
         self.img_encoder = VisualEncoder(**config['img_encoder'])
         self.text_encoder = TextEncoder(**config['text_encoder'])
-
-        self.preprocess_image = Compose([
-            Resize(224, interpolation=InterpolationMode.BICUBIC),
-            lambda x: x.convert('RGB'),
-            CenterCrop(224),
-            ToTensor(),
-            Normalize(
-                mean=(0.48145466, 0.4578275, 0.40821073),
-                std=(0.26862954, 0.26130258, 0.27577711)
-            )
-        ])
-
-        self._tokenizer = AutoTokenizer.from_pretrained(config['text_encoder']['backbone'])
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            config['text_encoder']['backbone'])
 
     def encode_image(self, x: torch.Tensor, return_features=False):
         features, embs = self.img_encoder(x)
@@ -278,7 +300,7 @@ def encode_image(self, x: torch.Tensor, return_features=False):
             return features, embs
 
         return embs
-
+
     def encode_text(self, x: dict, return_features=False):
         features, embs = self.text_encoder(x['input_ids'], x['attention_mask'])
 
@@ -293,7 +315,8 @@ def encode_multimodal(
         text: dict = None,
         image_features: torch.Tensor = None,
         text_features: torch.Tensor = None,
-        attention_mask: torch.Tensor = None):
+        attention_mask: torch.Tensor = None,
+    ):
 
         assert image is not None or image_features is not None, "Either 'image' or 'image_features' should be non None"
         assert text is not None or text_features is not None, "Either 'text_data' or 'text_features' should be non None"
@@ -318,26 +341,50 @@ def encode_multimodal(
 
     def get_matching_scores(
         self,
-        x: torch.Tensor):
+        x: torch.Tensor,
+    ):
         return self.text_encoder._logit_and_norm(x)
 
-
     def preprocess_text(self, x):
-        x = self._tokenizer(x,
-                            padding='max_length',
-                            truncation=True,
-                            return_tensors='pt',
-                            pad_to_max_length=True,
-                            max_length=77)
+        x = self._tokenizer(
+            x,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt',
+            pad_to_max_length=True,
+            max_length=77,
+        )
         if 'token_type_ids' in x:
             del x['token_type_ids']
-
+
         return x
 
+    def preprocess_image(self, x):
+        preprocessor = Compose([
+            Resize(224, interpolation=InterpolationMode.BICUBIC),
+            lambda x: x.convert('RGB'),
+            CenterCrop(224),
+            ToTensor(),
+            Normalize(
+                mean=(0.48145466, 0.4578275, 0.40821073),
+                std=(0.26862954, 0.26130258, 0.27577711)
+            )
+        ])
+        if isinstance(x, list):
+            images = []
+            for image in x:
+                images.append(preprocessor(image))
+
+            batch_images = torch.stack(images, dim=0)
+            return batch_images
+        else:
+            return preprocessor(x).unsqueeze(0)
+
+
 def get_model(model_name, token=None):
     model_path = snapshot_download(
         repo_id=model_name,
-        token=token
+        token=token,
     )
     config_path = f'{model_path}/config.json'
     state = torch.load(f'{model_path}/weight.pt')
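A short usage sketch of the two preprocessing helpers after this change; the shapes follow from the code above, and model is assumed to come from get_model:

from PIL import Image

image_data = model.preprocess_image(Image.open('red_panda.jpg'))
print(image_data.shape)  # torch.Size([1, 3, 224, 224]) - a single image gains a batch dim

text_data = model.preprocess_text('a small red panda in a zoo')
print(text_data['input_ids'].shape)       # torch.Size([1, 77])
print(text_data['attention_mask'].shape)  # torch.Size([1, 77])
# 'token_type_ids' is removed by preprocess_text, as shown above.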
