16
16
"""
17
17
18
18
import collections .abc
19
- from functools import partial
20
- import math
21
19
import logging
20
+ import math
21
+ from functools import partial
22
22
23
- import numpy as np
24
23
import torch
25
- from torch import nn
26
24
import torch .nn .functional as F
25
+ from torch import nn
27
26
28
27
from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
28
+ from .helpers import build_model_with_cfg , named_apply
29
29
from .layers import PatchEmbed , Mlp , DropPath , create_classifier , trunc_normal_
30
- from .layers .helpers import to_ntuple
31
- from .layers .create_conv2d import create_conv2d
32
- from .layers .pool2d_same import create_pool2d
33
- from .vision_transformer import Block
30
+ from .layers import create_conv2d , create_pool2d , to_ntuple
34
31
from .registry import register_model
35
- from .helpers import build_model_with_cfg , named_apply
36
- from .vision_transformer import resize_pos_embed
37
32
38
33
_logger = logging .getLogger (__name__ )
39
34
@@ -54,9 +49,12 @@ def _cfg(url='', **kwargs):
54
49
'nest_base' : _cfg (),
55
50
'nest_small' : _cfg (),
56
51
'nest_tiny' : _cfg (),
57
- 'jx_nest_base' : _cfg (url = 'https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth' ), # TODO
58
- 'jx_nest_small' : _cfg (url = 'https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth' ), # TODO
59
- 'jx_nest_tiny' : _cfg (url = 'https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth' ), # TODO
52
+ 'jx_nest_base' : _cfg (
53
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth' ),
54
+ 'jx_nest_small' : _cfg (
55
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth' ),
56
+ 'jx_nest_tiny' : _cfg (
57
+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth' ),
60
58
}
61
59
62
60
@@ -93,19 +91,18 @@ def forward(self, x):
93
91
x = (attn @ v ).permute (0 , 2 , 3 , 4 , 1 ).reshape (B , T , N , C )
94
92
x = self .proj (x )
95
93
x = self .proj_drop (x )
96
- return x # (B, T, N, C)
94
+ return x # (B, T, N, C)
97
95
98
96
99
- class TransformerLayer (Block ):
97
+ class TransformerLayer (nn . Module ):
100
98
"""
101
99
This is much like `.vision_transformer.Block` but:
102
100
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
103
101
- Uses modified Attention layer that handles the "block" dimension
104
102
"""
105
103
def __init__ (self , dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
106
104
act_layer = nn .GELU , norm_layer = nn .LayerNorm ):
107
- super ().__init__ (dim , num_heads , mlp_ratio = 4. , qkv_bias = False , drop = 0. , attn_drop = 0. , drop_path = 0. ,
108
- act_layer = nn .GELU , norm_layer = nn .LayerNorm )
105
+ super ().__init__ ()
109
106
self .norm1 = norm_layer (dim )
110
107
self .attn = Attention (dim , num_heads = num_heads , qkv_bias = qkv_bias , attn_drop = attn_drop , proj_drop = drop )
111
108
self .drop_path = DropPath (drop_path ) if drop_path > 0. else nn .Identity ()
@@ -120,7 +117,7 @@ def forward(self, x):
120
117
return x
121
118
122
119
123
- class BlockAggregation (nn .Module ):
120
+ class ConvPool (nn .Module ):
124
121
def __init__ (self , in_channels , out_channels , norm_layer , pad_type = '' ):
125
122
super ().__init__ ()
126
123
self .conv = create_conv2d (in_channels , out_channels , kernel_size = 3 , padding = pad_type , bias = True )
@@ -137,7 +134,7 @@ def forward(self, x):
137
134
# Layer norm done over channel dim only
138
135
x = self .norm (x .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
139
136
x = self .pool (x )
140
- return x # (B, C, H//2, W//2)
137
+ return x # (B, C, H//2, W//2)
141
138
142
139
143
140
def blockify (x , block_size : int ):
@@ -152,9 +149,8 @@ def blockify(x, block_size: int):
152
149
grid_height = H // block_size
153
150
grid_width = W // block_size
154
151
x = x .reshape (B , grid_height , block_size , grid_width , block_size , C )
155
- x = x .permute (0 , 1 , 3 , 2 , 4 , 5 )
156
- x = x .reshape (B , grid_height * grid_width , - 1 , C )
157
- return x # (B, T, N, C)
152
+ x = x .transpose (2 , 3 ).reshape (B , grid_height * grid_width , - 1 , C )
153
+ return x # (B, T, N, C)
158
154
159
155
160
156
def deblockify (x , block_size : int ):
@@ -163,23 +159,30 @@ def deblockify(x, block_size: int):
163
159
x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
164
160
block_size (int): edge length of a single square block in units of desired H, W
165
161
"""
166
- B , T , _ , C = x .shape
162
+ B , T , _ , C = x .shape
167
163
grid_size = int (math .sqrt (T ))
168
- x = x .reshape (B , grid_size , grid_size , block_size , block_size , C )
169
- x = x .permute (0 , 1 , 3 , 2 , 4 , 5 )
170
164
height = width = grid_size * block_size
171
- x = x .reshape (B , height , width , C )
172
- return x # (B, H, W, C)
173
-
165
+ x = x .reshape (B , grid_size , grid_size , block_size , block_size , C )
166
+ x = x .transpose (2 , 3 ).reshape (B , height , width , C )
167
+ return x # (B, H, W, C)
168
+
174
169
175
170
class NestLevel (nn .Module ):
176
171
""" Single hierarchical level of a Nested Transformer
177
172
"""
178
- def __init__ (self , num_blocks , block_size , seq_length , num_heads , depth , embed_dim , mlp_ratio = 4. , qkv_bias = True ,
179
- drop_rate = 0. , attn_drop_rate = 0. , drop_path_rates = [], norm_layer = None , act_layer = None ):
173
+ def __init__ (
174
+ self , num_blocks , block_size , seq_length , num_heads , depth , embed_dim , prev_embed_dim = None ,
175
+ mlp_ratio = 4. , qkv_bias = True , drop_rate = 0. , attn_drop_rate = 0. , drop_path_rates = [],
176
+ norm_layer = None , act_layer = None , pad_type = '' ):
180
177
super ().__init__ ()
181
178
self .block_size = block_size
182
179
self .pos_embed = nn .Parameter (torch .zeros (1 , num_blocks , seq_length , embed_dim ))
180
+
181
+ if prev_embed_dim is not None :
182
+ self .pool = ConvPool (prev_embed_dim , embed_dim , norm_layer = norm_layer , pad_type = pad_type )
183
+ else :
184
+ self .pool = nn .Identity ()
185
+
183
186
# Transformer encoder
184
187
if len (drop_path_rates ):
185
188
assert len (drop_path_rates ) == depth , 'Must provide as many drop path rates as there are transformer layers'
@@ -194,15 +197,14 @@ def forward(self, x):
194
197
"""
195
198
expects x as (B, C, H, W)
196
199
"""
197
- # Switch to channels last for transformer
198
- x = x .permute (0 , 2 , 3 , 1 ) # (B, H', W', C)
199
- x = blockify (x , self .block_size ) # (B, T, N, C')
200
+ x = self . pool ( x )
201
+ x = x .permute (0 , 2 , 3 , 1 ) # (B, H', W', C), switch to channels last for transformer
202
+ x = blockify (x , self .block_size ) # (B, T, N, C')
200
203
x = x + self .pos_embed
201
- x = self .transformer_encoder (x ) # (B, T, N, C')
202
- x = deblockify (x , self .block_size ) # (B, H', W', C')
204
+ x = self .transformer_encoder (x ) # (B, T, N, C')
205
+ x = deblockify (x , self .block_size ) # (B, H', W', C')
203
206
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
204
- x = x .permute (0 , 3 , 1 , 2 ) # (B, C, H', W')
205
- return x
207
+ return x .permute (0 , 3 , 1 , 2 ) # (B, C, H', W')
206
208
207
209
208
210
class Nest (nn .Module ):
@@ -213,9 +215,9 @@ class Nest(nn.Module):
213
215
"""
214
216
215
217
def __init__ (self , img_size = 224 , in_chans = 3 , patch_size = 4 , num_levels = 3 , embed_dims = (128 , 256 , 512 ),
216
- num_heads = (4 , 8 , 16 ), depths = (2 , 2 , 20 ), num_classes = 1000 , mlp_ratio = 4. , qkv_bias = True , pad_type = '' ,
217
- drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0.5 , norm_layer = None , act_layer = None , weight_init = '' ,
218
- global_pool = 'avg' ):
218
+ num_heads = (4 , 8 , 16 ), depths = (2 , 2 , 20 ), num_classes = 1000 , mlp_ratio = 4. , qkv_bias = True ,
219
+ drop_rate = 0. , attn_drop_rate = 0. , drop_path_rate = 0.5 , norm_layer = None , act_layer = None ,
220
+ pad_type = '' , weight_init = '' , global_pool = 'avg' ):
219
221
"""
220
222
Args:
221
223
img_size (int, tuple): input image size
@@ -233,6 +235,7 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
233
235
drop_path_rate (float): stochastic depth rate
234
236
norm_layer: (nn.Module): normalization layer for transformer layers
235
237
act_layer: (nn.Module): activation layer in MLP of transformer layers
238
+ pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
236
239
weight_init: (str): weight init scheme
237
240
global_pool: (str): type of pooling operation to apply to final feature map
238
241
@@ -254,6 +257,7 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
254
257
depths = to_ntuple (num_levels )(depths )
255
258
self .num_classes = num_classes
256
259
self .num_features = embed_dims [- 1 ]
260
+ self .feature_info = []
257
261
norm_layer = norm_layer or partial (nn .LayerNorm , eps = 1e-6 )
258
262
act_layer = act_layer or nn .GELU
259
263
self .drop_rate = drop_rate
@@ -265,60 +269,54 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
265
269
self .patch_size = patch_size
266
270
267
271
# Number of blocks at each level
268
- self .num_blocks = 4 ** ( np .arange (num_levels )[:: - 1 ] )
269
- assert (img_size // patch_size ) % np .sqrt (self .num_blocks [0 ]) == 0 , \
270
- 'First level blocks don\' t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
272
+ self .num_blocks = ( 4 ** torch .arange (num_levels )). flip ( 0 ). tolist ( )
273
+ assert (img_size // patch_size ) % math .sqrt (self .num_blocks [0 ]) == 0 , \
274
+ 'First level blocks don\' t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
271
275
272
276
# Block edge size in units of patches
273
277
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
274
278
# number of blocks along edge of image
275
- self .block_size = int ((img_size // patch_size ) // np .sqrt (self .num_blocks [0 ]))
279
+ self .block_size = int ((img_size // patch_size ) // math .sqrt (self .num_blocks [0 ]))
276
280
277
281
# Patch embedding
278
282
self .patch_embed = PatchEmbed (
279
- img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dims [0 ])
283
+ img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dims [0 ], flatten = False )
280
284
self .num_patches = self .patch_embed .num_patches
281
285
self .seq_length = self .num_patches // self .num_blocks [0 ]
282
286
283
287
# Build up each hierarchical level
284
- self .levels = nn .ModuleList ([])
285
- self .block_aggs = nn .ModuleList ([])
286
- drop_path_rates = [x .item () for x in torch .linspace (0 , drop_path_rate , sum (depths ))]
287
- for lix in range (self .num_levels ):
288
- dpr = drop_path_rates [sum (depths [:lix ]):sum (depths [:lix + 1 ])]
289
- self .levels .append (NestLevel (
290
- self .num_blocks [lix ], self .block_size , self .seq_length , num_heads [lix ], depths [lix ],
291
- embed_dims [lix ], mlp_ratio , qkv_bias , drop_rate , attn_drop_rate , dpr , norm_layer ,
292
- act_layer ))
293
- if lix < self .num_levels - 1 :
294
- self .block_aggs .append (BlockAggregation (
295
- embed_dims [lix ], embed_dims [lix + 1 ], norm_layer , pad_type = pad_type ))
296
- else :
297
- # Required for zipped iteration over levels and ls_block_agg together
298
- self .block_aggs .append (nn .Identity ())
288
+ levels = []
289
+ dp_rates = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
290
+ prev_dim = None
291
+ curr_stride = 4
292
+ for i in range (len (self .num_blocks )):
293
+ dim = embed_dims [i ]
294
+ levels .append (NestLevel (
295
+ self .num_blocks [i ], self .block_size , self .seq_length , num_heads [i ], depths [i ], dim , prev_dim ,
296
+ mlp_ratio , qkv_bias , drop_rate , attn_drop_rate , dp_rates [i ], norm_layer , act_layer , pad_type = pad_type ))
297
+ self .feature_info += [dict (num_chs = dim , reduction = curr_stride , module = f'levels.{ i } ' )]
298
+ prev_dim = dim
299
+ curr_stride *= 2
300
+ self .levels = nn .Sequential (* levels )
299
301
300
302
# Final normalization layer
301
303
self .norm = norm_layer (embed_dims [- 1 ])
302
304
303
305
# Classifier
304
- self .global_pool , self .head = create_classifier (
305
- self .num_features , self .num_classes , pool_type = global_pool )
306
+ self .global_pool , self .head = create_classifier (self .num_features , self .num_classes , pool_type = global_pool )
306
307
307
308
self .init_weights (weight_init )
308
309
309
310
def init_weights (self , mode = '' ):
310
- assert mode in ('jax' , 'jax_nlhb' , ' nlhb' , '' )
311
+ assert mode in ('nlhb' , '' )
311
312
head_bias = - math .log (self .num_classes ) if 'nlhb' in mode else 0.
312
313
for level in self .levels :
313
314
trunc_normal_ (level .pos_embed , std = .02 , a = - 2 , b = 2 )
314
- if mode .startswith ('jax' ):
315
- named_apply (partial (_init_nest_weights , head_bias = head_bias , jax_impl = True ), self )
316
- else :
317
- named_apply (_init_nest_weights , self )
315
+ named_apply (partial (_init_nest_weights , head_bias = head_bias ), self )
318
316
319
317
@torch .jit .ignore
320
318
def no_weight_decay (self ):
321
- return {' pos_embed' }
319
+ return {f'level. { i } . pos_embed' for i in range ( len ( self . levels )) }
322
320
323
321
def get_classifier (self ):
324
322
return self .head
@@ -333,13 +331,8 @@ def forward_features(self, x):
333
331
"""
334
332
B , _ , H , W = x .shape
335
333
x = self .patch_embed (x )
336
- x = x .reshape (B , H // self .patch_size , W // self .patch_size , - 1 ) # (B, H', W', C')
337
- x = x .permute (0 , 3 , 1 , 2 )
338
- # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
339
- for level , block_agg in zip (self .levels , self .block_aggs ):
340
- x = level (x )
341
- x = block_agg (x )
342
- # Layer norm done over channel dim only
334
+ x = self .levels (x )
335
+ # Layer norm done over channel dim only (to NHWC and back)
343
336
x = self .norm (x .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
344
337
return x
345
338
@@ -353,22 +346,19 @@ def forward(self, x):
353
346
return self .head (x )
354
347
355
348
356
- def _init_nest_weights (module : nn .Module , name : str = '' , head_bias : float = 0. , jax_impl : bool = False ):
349
+ def _init_nest_weights (module : nn .Module , name : str = '' , head_bias : float = 0. ):
357
350
""" NesT weight initialization
358
351
Can replicate Jax implementation. Otherwise follows vision_transformer.py
359
352
"""
360
353
if isinstance (module , nn .Linear ):
361
354
if name .startswith ('head' ):
362
- if jax_impl :
363
- trunc_normal_ (module .weight , std = .02 , a = - 2 , b = 2 )
364
- else :
365
- nn .init .zeros_ (module .weight )
355
+ trunc_normal_ (module .weight , std = .02 , a = - 2 , b = 2 )
366
356
nn .init .constant_ (module .bias , head_bias )
367
357
else :
368
358
trunc_normal_ (module .weight , std = .02 , a = - 2 , b = 2 )
369
359
if module .bias is not None :
370
- nn .init .zeros_ (module .bias )
371
- elif jax_impl and isinstance (module , nn .Conv2d ):
360
+ nn .init .zeros_ (module .bias )
361
+ elif isinstance (module , nn .Conv2d ):
372
362
trunc_normal_ (module .weight , std = .02 , a = - 2 , b = 2 )
373
363
if module .bias is not None :
374
364
nn .init .zeros_ (module .bias )
@@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
404
394
405
395
406
396
def _create_nest (variant , pretrained = False , default_cfg = None , ** kwargs ):
407
- if kwargs .get ('features_only' , None ):
408
- raise RuntimeError ('features_only not implemented for Vision Transformer models.' )
409
-
410
397
default_cfg = default_cfg or default_cfgs [variant ]
411
398
model = build_model_with_cfg (
412
399
Nest , variant , pretrained ,
413
400
default_cfg = default_cfg ,
401
+ feature_cfg = dict (out_indices = (0 , 1 , 2 ), flatten_sequential = True ),
414
402
pretrained_filter_fn = checkpoint_filter_fn ,
415
403
** kwargs )
416
404
@@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
422
410
""" Nest-B @ 224x224
423
411
"""
424
412
model_kwargs = dict (
425
- embed_dims = (128 , 256 , 512 ), num_heads = (4 , 8 , 16 ), depths = (2 , 2 , 20 ), drop_path_rate = 0.5 , ** kwargs )
413
+ embed_dims = (128 , 256 , 512 ), num_heads = (4 , 8 , 16 ), depths = (2 , 2 , 20 ), ** kwargs )
426
414
model = _create_nest ('nest_base' , pretrained = pretrained , ** model_kwargs )
427
415
return model
428
416
@@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
431
419
def nest_small (pretrained = False , ** kwargs ):
432
420
""" Nest-S @ 224x224
433
421
"""
434
- model_kwargs = dict (
435
- embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 20 ), drop_path_rate = 0.3 , ** kwargs )
422
+ model_kwargs = dict (embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 20 ), ** kwargs )
436
423
model = _create_nest ('nest_small' , pretrained = pretrained , ** model_kwargs )
437
424
return model
438
425
@@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
441
428
def nest_tiny (pretrained = False , ** kwargs ):
442
429
""" Nest-T @ 224x224
443
430
"""
444
- model_kwargs = dict (
445
- embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 8 ), drop_path_rate = 0.2 , ** kwargs )
431
+ model_kwargs = dict (embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 8 ), ** kwargs )
446
432
model = _create_nest ('nest_tiny' , pretrained = pretrained , ** model_kwargs )
447
433
return model
448
434
@@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
452
438
""" Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
453
439
"""
454
440
kwargs ['pad_type' ] = 'same'
455
- kwargs ['weight_init' ] = 'jax'
456
- model_kwargs = dict (
457
- embed_dims = (128 , 256 , 512 ), num_heads = (4 , 8 , 16 ), depths = (2 , 2 , 20 ), drop_path_rate = 0.5 , ** kwargs )
441
+ model_kwargs = dict (embed_dims = (128 , 256 , 512 ), num_heads = (4 , 8 , 16 ), depths = (2 , 2 , 20 ), ** kwargs )
458
442
model = _create_nest ('jx_nest_base' , pretrained = pretrained , ** model_kwargs )
459
443
return model
460
444
@@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
464
448
""" Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
465
449
"""
466
450
kwargs ['pad_type' ] = 'same'
467
- kwargs ['weight_init' ] = 'jax'
468
- model_kwargs = dict (
469
- embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 20 ), drop_path_rate = 0.3 , ** kwargs )
451
+ model_kwargs = dict (embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 20 ), ** kwargs )
470
452
model = _create_nest ('jx_nest_small' , pretrained = pretrained , ** model_kwargs )
471
453
return model
472
454
@@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
476
458
""" Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
477
459
"""
478
460
kwargs ['pad_type' ] = 'same'
479
- kwargs ['weight_init' ] = 'jax'
480
- model_kwargs = dict (
481
- embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 8 ), drop_path_rate = 0.2 , ** kwargs )
461
+ model_kwargs = dict (embed_dims = (96 , 192 , 384 ), num_heads = (3 , 6 , 12 ), depths = (2 , 2 , 8 ), ** kwargs )
482
462
model = _create_nest ('jx_nest_tiny' , pretrained = pretrained , ** model_kwargs )
483
- return model
463
+ return model
0 commit comments