@@ -74,26 +74,26 @@ def _cfg(url='', **kwargs):
     # Fiddling with configs / defaults / still pretraining
     'coatnet_pico_rw_224': _cfg(url=''),
     'coatnet_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
         crop_pct=0.9),
     'coatnet_0_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
     'coatnet_1_rw_224': _cfg(
-        url=''
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
     'coatnet_2_rw_224': _cfg(url=''),

     # Highly experimental configs
     'coatnet_bn_0_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         crop_pct=0.95),
     'coatnet_rmlp_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
         crop_pct=0.9),
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
     'coatnet_nano_cc_224': _cfg(url=''),
     'coatnext_nano_rw_224': _cfg(url=''),

@@ -107,10 +107,12 @@ def _cfg(url='', **kwargs):

     # Experimental configs
     'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

     # Trying to be like the MaxViT paper configs
@@ -131,7 +133,7 @@ class MaxxVitTransformerCfg:
     attn_bias: bool = True
     attn_drop: float = 0.
     proj_drop: float = 0.
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
     window_size: Tuple[int, int] = (7, 7)
@@ -153,7 +155,7 @@ class MaxxVitConvCfg:
     pre_norm_act: bool = False  # activation after pre-norm
     output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
     stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     downsample_pool_type: str = 'avg2'
     attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
     attn_layer: str = 'se'
@@ -241,7 +243,7 @@ def _rw_coat_cfg(

 def _rw_max_cfg(
         stride_mode='dw',
-        pool_type='avg',
+        pool_type='avg2',
         conv_output_bias=False,
         conv_attn_ratio=1 / 16,
         conv_norm_layer='',
@@ -325,7 +327,6 @@ def _next_cfg(
         depths=(2, 3, 5, 2),
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -336,7 +337,6 @@ def _next_cfg(
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
             stride_mode='pool',
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -384,7 +384,6 @@ def _next_cfg(
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         **_rw_max_cfg(
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
             rel_pos_type='mlp',
@@ -487,10 +486,10 @@ def _next_cfg(
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
-    maxvit_tiny_cm_256=MaxxVitCfg(
+    maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
-        block_type=('CM',) * 4,
+        block_type=('PM',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
@@ -663,13 +662,15 @@ def __init__(
             bias: bool = True,
     ):
         super().__init__()
-        assert pool_type in ('max', 'avg', 'avg2')
+        assert pool_type in ('max', 'max2', 'avg', 'avg2')
         if pool_type == 'max':
             self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        elif pool_type == 'max2':
+            self.pool = nn.MaxPool2d(2)  # kernel_size == stride == 2
         elif pool_type == 'avg':
             self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
         else:
-            self.pool = nn.AvgPool2d(2)
+            self.pool = nn.AvgPool2d(2)  # kernel_size == stride == 2

         if dim != dim_out:
             self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
@@ -1073,7 +1074,7 @@ def forward(self, x):
         return x


-class CombinedPartitionAttention(nn.Module):
+class ParallelPartitionAttention(nn.Module):
     """ Experimental. Grid and Block partition + single FFN
     NxC tensor layout.
     """
@@ -1286,7 +1287,7 @@ def forward(self, x):
         return x


-class CombinedMaxxVitBlock(nn.Module):
+class ParallelMaxxVitBlock(nn.Module):
     """
     """

@@ -1309,7 +1310,7 @@ def __init__(
             self.conv = nn.Sequential(*convs)
         else:
             self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
-        self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
+        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

     def init_weights(self, scheme=''):
         named_apply(partial(_init_transformer, scheme=scheme), self.attn)
@@ -1343,7 +1344,7 @@ def __init__(
         blocks = []
         for i, t in enumerate(block_types):
             block_stride = stride if i == 0 else 1
-            assert t in ('C', 'T', 'M', 'CM')
+            assert t in ('C', 'T', 'M', 'PM')
             if t == 'C':
                 conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                 blocks += [conv_cls(
@@ -1372,8 +1373,8 @@ def __init__(
                     transformer_cfg=transformer_cfg,
                     drop_path=drop_path[i],
                 )]
-            elif t == 'CM':
-                blocks += [CombinedMaxxVitBlock(
+            elif t == 'PM':
+                blocks += [ParallelMaxxVitBlock(
                     in_chs,
                     out_chs,
                     stride=block_stride,
@@ -1415,7 +1416,6 @@ def __init__(
         self.norm1 = norm_act_layer(out_chs[0])
         self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1)

-    @torch.jit.ignore
     def init_weights(self, scheme=''):
         named_apply(partial(_init_conv, scheme=scheme), self)

@@ -1659,8 +1659,8 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):


 @register_model
-def maxvit_tiny_cm_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs)
+def maxvit_tiny_pm_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)


 @register_model
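
For reference, a minimal sketch of how one of the newly weighted variants can be exercised once this change is in an installed timm build. It assumes network access for the checkpoint download; the model name and input size are taken from the default_cfgs entries added above.

import torch
import timm

# Assumes a timm version that includes these registrations; pretrained=True
# downloads the v0.1-weights-maxx checkpoint referenced in the cfg above.
model = timm.create_model('maxvit_nano_rw_256', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))  # input_size=(3, 256, 256) per the cfg
print(logits.shape)  # torch.Size([1, 1000]) with the default ImageNet-1k head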