-
Notifications
You must be signed in to change notification settings - Fork 323
/
Copy pathcvt.py
626 lines (552 loc) · 21 KB
/
cvt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CvT in Paddle
A Paddle Implementation of CvT as described in:
"CvT: Introducing Convolutions to Vision Transformers"
- Paper Link: https://arxiv.org/abs/2103.15808
"""
from numpy import repeat
import os
import paddle
import paddle.nn as nn
from droppath import DropPath
class QuickGELU(nn.Layer):
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x).

    Intended as a faster stand-in for nn.GELU; the output approximates,
    but is not identical to, the exact GELU.
    """

    def forward(self, x: paddle.Tensor):
        gate = nn.functional.sigmoid(1.702 * x)
        return x * gate
class Mlp(nn.Layer):
    """ MLP module
    Two-layer feed-forward network with activation and dropout.
    Ops: fc1 -> act -> dropout1 -> fc2 -> dropout2
    Args:
        embed_dim: input and output feature size
        mlp_ratio: the hidden size is int(embed_dim * mlp_ratio)
        act_layer: activation class, instantiated with no arguments
        dropout: drop probability used by both dropout layers
    Attributes:
        fc1: nn.Linear(embed_dim, hidden)
        fc2: nn.Linear(hidden, embed_dim)
        act: act_layer()
        dropout1: dropout after act
        dropout2: dropout after fc2
    """

    def __init__(self,
                 embed_dim,
                 mlp_ratio,
                 act_layer=nn.GELU,
                 dropout=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        fc1_weight, fc1_bias = self._init_weights()
        self.fc1 = nn.Linear(embed_dim, hidden_dim,
                             weight_attr=fc1_weight, bias_attr=fc1_bias)
        fc2_weight, fc2_bias = self._init_weights()
        self.fc2 = nn.Linear(hidden_dim, embed_dim,
                             weight_attr=fc2_weight, bias_attr=fc2_bias)
        self.act = act_layer()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated normal (std=.02) weights, zero bias."""
        return (
            paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=.02)),
            paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)),
        )

    def forward(self, x):
        for layer in (self.fc1, self.act, self.dropout1, self.fc2, self.dropout2):
            x = layer(x)
        return x
class ConvEmbed(nn.Layer):
    """ Image to Conv Embedding
    Projects an image (or the previous stage's feature map) with a strided
    nn.Conv2D and optionally normalizes the result over the channel dim.
    Ops: conv -> (flatten to tokens -> norm -> back to grid)
    Args:
        patch_size: int k for a square (k, k) kernel, or an explicit (kh, kw)
        in_chans: input channels
        embed_dim: output channels
        stride: conv stride
        padding: conv padding
        norm_layer: optional norm class, instantiated as norm_layer(embed_dim)
    Attributes:
        patch_size: (kh, kw) kernel size of the projection conv
        proj: nn.Conv2D
        norm: norm_layer(embed_dim) or None. nn.LayerNorm normalizes the last
            dim, so the [B, C, H, W] map is flattened to [B, H*W, C] before
            the norm and restored afterwards.
    """
    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        # Expand an int kernel size to a square (k, k) and keep an explicit
        # pair as-is. numpy.repeat((7, 7), 2) would give a 4-element kernel.
        if isinstance(patch_size, (tuple, list)):
            patch_size = tuple(patch_size)
        else:
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.proj = nn.Conv2D(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)
        if self.norm is None:
            # without a norm the flatten/unflatten round-trip is a no-op
            return x
        B, C, H, W = x.shape
        # [B, C, H, W] -> [B, H*W, C] so the norm sees channels last
        x = paddle.transpose(x, [0, 2, 3, 1])
        x = paddle.reshape(x, [B, H * W, C])
        x = self.norm(x)
        # [B, H*W, C] -> [B, C, H, W]
        x = paddle.transpose(x, [0, 2, 1])
        x = paddle.reshape(x, [B, C, H, W])
        return x
class Attention(nn.Layer):
    """ Attention module
    Multi-head self-attention for CvT. The input tokens are reshaped to an
    h x w grid, q/k/v are produced by depth-wise convolutional projections
    (Conv2D + BatchNorm2D), flattened back to tokens, then mapped by
    per-token linear layers before scaled dot-product attention.
    Args:
        dim_in: channel dim of the input tokens
        dim_out: channel dim of q/k/v and of the output
        num_heads: number of attention heads (dim_out is split across them)
        qkv_bias: if True, proj_q/proj_k/proj_v have a bias
        attn_drop: dropout rate on the attention weights
        proj_drop: dropout rate after the output projection
        kernel_size: kernel size of the depth-wise conv projections
        stride_kv / padding_kv: stride / padding of the k and v conv projections
        stride_q / padding_q: stride / padding of the q conv projection
        with_cls_token: if True the first input token is a cls token; it
            bypasses the conv projections and is prepended to q, k and v
        **kwargs: accepted and ignored
    Attributes:
        num_heads: number of heads
        scale: dim_out ** -0.5. NOTE: this uses the full dim_out rather than
            the per-head dim (dim_out // num_heads); left unchanged so outputs
            stay identical to the existing implementation.
        conv_proj_q / conv_proj_k / conv_proj_v: depth-wise Conv2D -> BatchNorm2D
        proj_q / proj_k / proj_v: nn.Linear(dim_in, dim_out)
        attn_drop: dropout on the attention weights
        proj: nn.Linear(dim_out, dim_out) applied to the merged heads
        proj_drop: dropout after proj
    """
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 kernel_size=3,
                 stride_kv=2,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token
        # depth-wise conv projections; k and v may be spatially downsampled
        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q, stride_q)
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv, stride_kv)
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv, stride_kv)
        # per-token linear projections of q, k, v
        w_attr_1, b_attr_1 = self._init_weights()
        w_attr_2, b_attr_2 = self._init_weights()
        w_attr_3, b_attr_3 = self._init_weights()
        self.proj_q = nn.Linear(dim_in, dim_out, weight_attr=w_attr_1,
                                bias_attr=b_attr_1 if qkv_bias else False)
        self.proj_k = nn.Linear(dim_in, dim_out, weight_attr=w_attr_2,
                                bias_attr=b_attr_2 if qkv_bias else False)
        self.proj_v = nn.Linear(dim_in, dim_out, weight_attr=w_attr_3,
                                bias_attr=b_attr_3 if qkv_bias else False)
        self.attn_drop = nn.Dropout(attn_drop)
        w_attr_4, b_attr_4 = self._init_weights()
        self.proj = nn.Linear(dim_out, dim_out, weight_attr=w_attr_4, bias_attr=b_attr_4)
        self.proj_drop = nn.Dropout(proj_drop)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated normal (std=.02) weights, zero bias."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          ):
        """Depth-wise Conv2D (groups=dim_in, no bias) followed by BatchNorm2D.

        The channel count stays dim_in; dim_out is accepted but unused here.
        """
        proj = nn.Sequential(
            nn.Conv2D(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                bias_attr=False,
                groups=dim_in
            ),
            nn.BatchNorm2D(dim_in),
        )
        return proj

    @staticmethod
    def _project(conv_proj, x):
        """Apply conv_proj to x, or pass x through unchanged when it is None."""
        return x if conv_proj is None else conv_proj(x)

    @staticmethod
    def _grid_to_tokens(x):
        """Flatten a [B, C, H, W] map to [B, H*W, C] tokens."""
        B, C, H, W = x.shape
        x = paddle.transpose(x, [0, 2, 3, 1])
        return paddle.reshape(x, [B, H * W, C])

    def _split_heads(self, t):
        """Reshape [B, T, num_heads*d] to [B, num_heads, T, d]."""
        B, T, _ = t.shape
        t = paddle.reshape(t, [B, T, self.num_heads, -1])
        return paddle.transpose(t, [0, 2, 1, 3])

    def forward_conv(self, x, h, w):
        """Compute q, k, v token sequences with the conv projections.
        Args:
            x: input tokens, expected [B, h*w, C], or [B, 1 + h*w, C] when
                with_cls_token is True
            h, w: spatial size of the token grid
        Returns:
            q, k, v: [B, (1 +) T, C] each, where T is the flattened size of
                the corresponding projection's output grid
        """
        cls_token = None
        if self.with_cls_token:
            # the cls token does not take part in the spatial convolution
            cls_token, x = paddle.split(x, [1, h * w], 1)
        B, _, C = x.shape
        # [B, h*w, C] -> [B, C, h, w]
        x = paddle.transpose(x, [0, 2, 1])
        x = paddle.reshape(x, [B, C, h, w])

        q = self._grid_to_tokens(self._project(self.conv_proj_q, x))
        k = self._grid_to_tokens(self._project(self.conv_proj_k, x))
        v = self._grid_to_tokens(self._project(self.conv_proj_v, x))

        if self.with_cls_token:
            q = paddle.concat([cls_token, q], axis=1)
            k = paddle.concat([cls_token, k], axis=1)
            v = paddle.concat([cls_token, v], axis=1)
        return q, k, v

    def forward(self, x, h, w):
        """Multi-head attention over x laid out on an h x w grid.

        Returns a [B, T_q, dim_out] tensor, where T_q is the number of query
        tokens (including the cls token when with_cls_token is True).
        """
        # forward_conv passes x through for any projection that is None, so
        # q, k, v are always defined
        q, k, v = self.forward_conv(x, h, w)
        q = self._split_heads(self.proj_q(q))
        k = self._split_heads(self.proj_k(k))
        v = self._split_heads(self.proj_v(v))
        attn_score = paddle.matmul(q, k, transpose_y=True) * self.scale
        attn = nn.functional.softmax(attn_score, axis=-1)
        attn = self.attn_drop(attn)
        x = paddle.matmul(attn, v)
        # [B, heads, T, d] -> [B, T, heads*d]; 0 in reshape keeps that dim
        x = paddle.transpose(x, [0, 2, 1, 3])
        x = paddle.reshape(x, [0, 0, -1])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Layer):
    ''' Block module
    Pre-norm transformer block used in every CvT stage.
    Ops: x -> norm1 -> attn (tokens reshaped to an h x w grid inside) ->
         drop_path -> + x -> norm2 -> mlp -> drop_path -> + residual
    Args:
        dim_in / dim_out: token dims before / after attention
        num_heads: attention heads
        mlp_ratio: hidden-size ratio of the Mlp
        qkv_bias: bias on the q/k/v linear projections
        drop: dropout rate used both after the attention projection and in the Mlp
        attn_drop: dropout rate on the attention weights
        drop_path: DropPath rate; 0 disables it (nn.Identity)
        act_layer / norm_layer: activation and norm classes
        **kwargs: forwarded to Attention (e.g. with_cls_token, kernel_size,
            stride_kv, stride_q, padding_kv, padding_q)
    NOTE(review): the first residual adds the attention output (dim_out) to
    the block input (dim_in), so this presumably requires dim_in == dim_out.
    '''
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 **kwargs):
        super().__init__()
        # Default to Attention's own with_cls_token default instead of
        # raising KeyError when the caller omits it.
        self.with_cls_token = kwargs.get('with_cls_token', True)
        self.norm1 = norm_layer(dim_in)
        self.attn = Attention(
            dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop,
            **kwargs
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim_out)
        self.mlp = Mlp(
            dim_out,
            mlp_ratio,
            act_layer=act_layer,
            dropout=drop
        )

    def forward(self, x, h, w):
        x = x + self.drop_path(self.attn(self.norm1(x), h, w))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class VisionTransformer(nn.Layer):
    """ VisionTransformer module
    One CvT stage: a convolutional token embedding followed by `depth`
    transformer Blocks.
    Ops: input -> patch_embed (ConvEmbed) -> flatten to tokens ->
         (prepend cls token) -> pos_drop -> depth * Block ->
         (split cls token) -> tokens back to [B, C, H, W]
    Args:
        patch_size / patch_stride / patch_padding: ConvEmbed conv settings
        in_chans: input channels
        embed_dim: token dim of this stage
        depth: number of Blocks
        num_heads, mlp_ratio, qkv_bias: forwarded to every Block
        drop_rate: dropout after the embedding and inside each Block
        attn_drop_rate: attention-weight dropout inside each Block
        drop_path_rate: DropPath rate of the last Block; rates grow linearly
            from 0 across the blocks
        act_layer / norm_layer: activation and norm classes
        init: 'xavier' uses Xavier-normal init for nn.Linear weights; any
            other value uses truncated normal (std=0.02)
        **kwargs: forwarded to every Block. 'with_cls_token' (default True)
            also decides whether this stage owns a cls token parameter.
    forward returns:
        x: feature map [B, embed_dim, H, W]
        cls_tokens: [B, 1, embed_dim], or None when the stage has no cls token
    """
    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer=QuickGELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 **kwargs):
        super().__init__()
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = ConvEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer
        )
        # Default True matches Attention's default instead of raising KeyError;
        # the value is forwarded explicitly to every Block below.
        with_cls_token = kwargs.pop('with_cls_token', True)
        if with_cls_token:
            self.cls_token = paddle.create_parameter(
                shape=[1, 1, embed_dim],
                dtype='float32',
                default_initializer=nn.initializer.TruncatedNormal(std=.02))
        else:
            self.cls_token = None
        self.pos_drop = nn.Dropout(p=drop_rate)
        # per-block DropPath rates, linearly spaced from 0 to drop_path_rate
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.LayerList([
            Block(
                dim_in=embed_dim,
                dim_out=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[j],
                act_layer=act_layer,
                norm_layer=norm_layer,
                with_cls_token=with_cls_token,
                **kwargs
            )
            for j in range(depth)
        ])
        if init == 'xavier':
            self.apply(self._init_weights_xavier)
        else:
            self.apply(self._init_weights_trunc_normal)

    def _init_weights_trunc_normal(self, m):
        """Truncated-normal (std=0.02) Linear weights, zero biases, unit/zero norms."""
        if isinstance(m, nn.Linear):
            trun_init = nn.initializer.TruncatedNormal(std=0.02)
            trun_init(m.weight)
            if m.bias is not None:
                zeros = nn.initializer.Constant(0.)
                zeros(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zeros = nn.initializer.Constant(0.)
            zeros(m.bias)
            ones = nn.initializer.Constant(1.0)
            ones(m.weight)

    def _init_weights_xavier(self, m):
        """Xavier-normal Linear weights, zero biases, unit/zero norms."""
        if isinstance(m, nn.Linear):
            xavier_init = nn.initializer.XavierNormal()
            xavier_init(m.weight)
            if m.bias is not None:
                zeros = nn.initializer.Constant(0.)
                zeros(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2D)):
            zeros = nn.initializer.Constant(0.)
            zeros(m.bias)
            ones = nn.initializer.Constant(1)
            ones(m.weight)

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        # [B, C, H, W] -> [B, H*W, C]
        x = paddle.transpose(x, [0, 2, 3, 1])
        x = paddle.reshape(x, [B, H * W, C])
        cls_tokens = None
        if self.cls_token is not None:
            cls_tokens = paddle.expand(self.cls_token, [B, -1, -1])
            x = paddle.concat([cls_tokens, x], axis=1)
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x, H, W)
        if self.cls_token is not None:
            cls_tokens, x = paddle.split(x, [1, H * W], 1)
        # [B, H*W, C] -> [B, C, H, W]
        B, L, C = x.shape
        x = paddle.transpose(x, [0, 2, 1])
        x = paddle.reshape(x, [B, C, H, W])
        return x, cls_tokens
class ConvolutionalVisionTransformer(nn.Layer):
    '''CvT model
    Introducing Convolutions to Vision Transformers
    Args:
        in_chans: int, input image channels, default: 3
        num_classes: int, number of classes for classification; 0 or less
            replaces the head with nn.Identity, default: 1000
        num_stage: int, number of stages; every per-stage sequence below
            must have at least this many entries, default: 3
        patch_size: int sequence, patch size, default: (7, 3, 3)
        patch_stride: int sequence, patch stride, default: (4, 2, 2)
        patch_padding: int sequence, patch padding, default: (2, 1, 1)
        embed_dim: int sequence, embedding dimension (patch embed out dim), default: (64, 192, 384)
        depth: int sequence, number of transformer blocks, default: (1, 2, 10)
        num_heads: int sequence, number of attention heads, default: (1, 3, 6)
        drop_rate: float sequence, dropout rate of each stage, default: (0.0, 0.0, 0.0)
        attn_drop_rate: float sequence, attention-weight dropout rate, default: (0.0, 0.0, 0.0)
        drop_path_rate: float sequence, max droppath rate of each stage, default: (0.0, 0.0, 0.1)
        with_cls_token: bool sequence, whether each stage uses a cls token;
            the last entry also selects cls-token vs mean-pooled features,
            default: (False, False, True)
    '''
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 num_stage=3,
                 patch_size=(7, 3, 3),
                 patch_stride=(4, 2, 2),
                 patch_padding=(2, 1, 1),
                 embed_dim=(64, 192, 384),
                 depth=(1, 2, 10),
                 num_heads=(1, 3, 6),
                 drop_rate=(0.0, 0.0, 0.0),
                 attn_drop_rate=(0.0, 0.0, 0.0),
                 drop_path_rate=(0.0, 0.0, 0.1),
                 with_cls_token=(False, False, True),
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_stages = num_stage
        self.stages = nn.LayerList()
        for i in range(self.num_stages):
            stage = VisionTransformer(
                in_chans=in_chans,
                patch_size=patch_size[i],
                patch_stride=patch_stride[i],
                patch_padding=patch_padding[i],
                embed_dim=embed_dim[i],
                depth=depth[i],
                num_heads=num_heads[i],
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_rate=drop_rate[i],
                attn_drop_rate=attn_drop_rate[i],
                drop_path_rate=drop_path_rate[i],
                with_cls_token=with_cls_token[i],
            )
            self.stages.append(stage)
            # the next stage consumes this stage's feature map
            in_chans = embed_dim[i]
        dim_embed = embed_dim[-1]
        self.norm = nn.LayerNorm(dim_embed)
        # bool: whether the final features come from the cls token
        self.cls_token = with_cls_token[-1]
        # Classifier head
        self.head = nn.Linear(
            dim_embed, num_classes) if num_classes > 0 else nn.Identity()
        # nn.Identity has no weight to initialize
        if isinstance(self.head, nn.Linear):
            trunc_init = nn.initializer.TruncatedNormal(std=0.02)
            trunc_init(self.head.weight)

    @staticmethod
    def _resize_pos_embed(posemb, ntok_new):
        """Bilinearly resize a [1, 1 + gs_old**2, C] position embedding.

        The first token is kept as-is. The remaining square grid is resized
        to ntok_new tokens, assumed to be a perfect square.
        """
        posemb = paddle.to_tensor(posemb)
        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
        gs_old = int(posemb_grid.shape[0] ** 0.5)
        gs_new = int(ntok_new ** 0.5)
        grid = paddle.reshape(posemb_grid, [1, gs_old, gs_old, -1])
        grid = paddle.transpose(grid, [0, 3, 1, 2])
        grid = nn.functional.interpolate(
            grid, size=[gs_new, gs_new], mode='bilinear', align_corners=True)
        grid = paddle.transpose(grid, [0, 2, 3, 1])
        grid = paddle.reshape(grid, [1, gs_new * gs_new, -1])
        return paddle.concat([posemb_tok, grid], axis=1)

    def init_weights(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load matching weights from a Paddle checkpoint file.
        Args:
            pretrained: path to a checkpoint readable by paddle.load; nothing
                happens if it is not an existing file
            pretrained_layers: top-level names (first component of a state
                dict key) to load. If the first element is '*', every key
                present in this model is loaded. None loads nothing.
            verbose: currently unused
        """
        if not os.path.isfile(pretrained):
            return
        pretrained_layers = pretrained_layers or []
        load_all = bool(pretrained_layers) and pretrained_layers[0] == '*'
        # paddle.load does not accept a torch-style map_location argument
        pretrained_dict = paddle.load(pretrained)
        model_dict = self.state_dict()
        need_init_state_dict = {}
        for k, v in pretrained_dict.items():
            if k not in model_dict:
                continue
            if not (load_all or k.split('.')[0] in pretrained_layers):
                continue
            if 'pos_embed' in k and list(v.shape) != list(model_dict[k].shape):
                # token count excluding the leading cls token
                v = self._resize_pos_embed(v, model_dict[k].shape[1] - 1)
            need_init_state_dict[k] = v
        # nn.Layer has no load_state_dict(strict=...); set_state_dict reports
        # (rather than raises on) parameters missing from the given dict
        self.set_state_dict(need_init_state_dict)

    def forward_features(self, x):
        """Run all stages and return pooled features of shape [B, dim_embed]."""
        cls_tokens = None
        for stage in self.stages:
            x, cls_tokens = stage(x)
        if self.cls_token:
            # [B, 1, C] -> [B, C]; squeeze only axis 1 so B == 1 keeps its batch dim
            x = paddle.squeeze(self.norm(cls_tokens), axis=1)
        else:
            # 'b c h w -> b (h w) c', then normalize and average over tokens
            B, C, H, W = x.shape
            x = paddle.transpose(x, [0, 2, 3, 1])
            x = paddle.reshape(x, [B, H * W, C])
            x = self.norm(x)
            x = paddle.mean(x, axis=1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
def build_cvt(config):
    """Build a ConvolutionalVisionTransformer from `config.MODEL`.

    Reads NUM_CLASSES, NUM_STAGES, PATCH_SIZE, PATCH_STRIDE, PATCH_PADDING,
    DIM_EMBED, DEPTH, NUM_HEADS, DROP_RATE, ATTN_DROP_RATE, DROP_PATH_RATE
    and CLS_TOKEN from `config.MODEL`; the input channel count is fixed to 3.
    """
    model_cfg = config.MODEL
    return ConvolutionalVisionTransformer(
        in_chans=3,
        num_classes=model_cfg.NUM_CLASSES,
        num_stage=model_cfg.NUM_STAGES,
        patch_size=model_cfg.PATCH_SIZE,
        patch_stride=model_cfg.PATCH_STRIDE,
        patch_padding=model_cfg.PATCH_PADDING,
        embed_dim=model_cfg.DIM_EMBED,
        depth=model_cfg.DEPTH,
        num_heads=model_cfg.NUM_HEADS,
        drop_rate=model_cfg.DROP_RATE,
        attn_drop_rate=model_cfg.ATTN_DROP_RATE,
        drop_path_rate=model_cfg.DROP_PATH_RATE,
        with_cls_token=model_cfg.CLS_TOKEN,
    )