Commit 81cd686
committed Jul 6, 2021

Move aggregation (ConvPool) for NesT into NestLevel, clean up, and enable features_only use. Finalize weight URLs.

1 parent 6ae0ac6 commit 81cd686
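
This change turns each NesT level into a self-contained stage, which is what makes multi-scale feature extraction work. A minimal usage sketch, assuming the timm API at this commit and the model names registered below:

```python
import timm
import torch

# features_only support for NesT is enabled by this commit
backbone = timm.create_model('nest_base', features_only=True)
feats = backbone(torch.zeros(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # strides 4, 8, 16 -> (1, 128, 56, 56), (1, 256, 28, 28), (1, 512, 14, 14)
```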

3 files changed: +95 -112 lines changed

README.md  (+3)

@@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### July 5, 2021
+* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
+
 ### June 23, 2021
 * Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
 
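
A quick classification sanity check for the new model, assuming the converted weights resolve via the default_cfg URLs finalized in this commit:

```python
import timm
import torch

model = timm.create_model('jx_nest_base', pretrained=True).eval()
with torch.no_grad():
    probs = model(torch.zeros(1, 3, 224, 224)).softmax(-1)
print(probs.shape)  # (1, 1000)
```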

convert/convert_nest_flax.py  (+12 -12)

@@ -79,18 +79,18 @@ def convert_nest(checkpoint_path, arch):
                 state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
                     flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
 
-    # Block aggregations
-    for level in range(len(depths)-1):
+    # Block aggregations (ConvPool)
+    for level in range(1, len(depths)):
         # Convs
-        state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
-        state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['Conv_0']['bias'])
+        state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
+        state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['Conv_0']['bias'])
         # Norms
-        state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale'])
-        state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor(
-            flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias'])
+        state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['scale'])
+        state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(
+            flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['bias'])
 
     # Final norm
     state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
@@ -105,5 +105,5 @@ def convert_nest(checkpoint_path, arch):
 
 if __name__ == '__main__':
     variant = sys.argv[1]  # base, small, or tiny
-    state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
-    torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth')
+    state_dict = convert_nest(f'./nest-{variant[0]}_imagenet', f'nest_{variant}')
+    torch.save(state_dict, f'./jx_nest_{variant}.pth')
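
Note the index shift in the remap above: the old layout stored aggregation as `block_aggs.{level}` after level `level`, while the new layout stores it as the pool at the start of the following level, so Flax `ConvPool_{level-1}` lands in `levels.{level}.pool`. A hypothetical helper (not part of this commit) for migrating old-style PyTorch checkpoints:

```python
def remap_old_nest_keys(state_dict):
    # Hypothetical: 'block_aggs.{i}.*' (aggregation after level i) becomes
    # 'levels.{i + 1}.pool.*' (pooling at the start of level i + 1).
    out = {}
    for k, v in state_dict.items():
        if k.startswith('block_aggs.'):
            _, i, rest = k.split('.', 2)
            k = f'levels.{int(i) + 1}.pool.{rest}'
        out[k] = v
    return out
```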

timm/models/nest.py  (+80 -100)

@@ -16,24 +16,19 @@
 """
 
 import collections.abc
-from functools import partial
-import math
 import logging
+import math
+from functools import partial
 
-import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
-from .layers.helpers import to_ntuple
-from .layers.create_conv2d import create_conv2d
-from .layers.pool2d_same import create_pool2d
-from .vision_transformer import Block
+from .layers import create_conv2d, create_pool2d, to_ntuple
 from .registry import register_model
-from .helpers import build_model_with_cfg, named_apply
-from .vision_transformer import resize_pos_embed
 
 _logger = logging.getLogger(__name__)
 
@@ -54,9 +49,12 @@ def _cfg(url='', **kwargs):
     'nest_base': _cfg(),
     'nest_small': _cfg(),
     'nest_tiny': _cfg(),
-    'jx_nest_base': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_base.pth'),  # TODO
-    'jx_nest_small': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_small.pth'),  # TODO
-    'jx_nest_tiny': _cfg(url='https://www.todo-this-is-a-placeholder.com/jx_nest_tiny.pth'),  # TODO
+    'jx_nest_base': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
+    'jx_nest_small': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
+    'jx_nest_tiny': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
 }
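
The finalized URLs are plain release assets, so a checkpoint can also be fetched directly with standard torch.hub machinery (hedged sketch):

```python
import torch

sd = torch.hub.load_state_dict_from_url(
    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth',
    map_location='cpu')
print(len(sd), 'tensors')
```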
@@ -93,19 +91,18 @@ def forward(self, x):
         x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
-        return x # (B, T, N, C)
+        return x  # (B, T, N, C)
 
 
-class TransformerLayer(Block):
+class TransformerLayer(nn.Module):
     """
     This is much like `.vision_transformer.Block` but:
         - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
         - Uses modified Attention layer that handles the "block" dimension
     """
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
                  act_layer=nn.GELU, norm_layer=nn.LayerNorm):
-        super().__init__(dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
-                         act_layer=nn.GELU, norm_layer=nn.LayerNorm)
+        super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
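
Dropping the `Block` base class makes the per-block shape contract explicit: attention mixes tokens within each block only. A shape sketch, assuming the module path and the full definitions in this file:

```python
import torch
from timm.models.nest import TransformerLayer  # assumed import path

layer = TransformerLayer(dim=128, num_heads=4, qkv_bias=True)
x = torch.zeros(2, 16, 196, 128)  # (B, T, N, C): 16 image blocks, 196 tokens each
y = layer(x)
assert y.shape == x.shape
```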
@@ -120,7 +117,7 @@ def forward(self, x):
         return x
 
 
-class BlockAggregation(nn.Module):
+class ConvPool(nn.Module):
     def __init__(self, in_channels, out_channels, norm_layer, pad_type=''):
         super().__init__()
         self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
@@ -137,7 +134,7 @@ def forward(self, x):
         # Layer norm done over channel dim only
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.pool(x)
-        return x # (B, C, H//2, W//2)
+        return x  # (B, C, H//2, W//2)
 
 
 def blockify(x, block_size: int):
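
ConvPool (above) projects to the next level's width and halves the spatial resolution: 3x3 conv, channels-last LayerNorm, stride-2 max pool. A hedged shape check, assuming the module path:

```python
import torch
from torch import nn
from timm.models.nest import ConvPool  # assumed import path

pool = ConvPool(128, 256, norm_layer=nn.LayerNorm, pad_type='')
x = torch.zeros(2, 128, 28, 28)
assert pool(x).shape == (2, 256, 14, 14)  # C: 128 -> 256, H and W halved
```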
@@ -152,9 +149,8 @@ def blockify(x, block_size: int):
     grid_height = H // block_size
     grid_width = W // block_size
     x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
-    x = x.permute(0, 1, 3, 2, 4, 5)
-    x = x.reshape(B, grid_height * grid_width, -1, C)
-    return x # (B, T, N, C)
+    x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
+    return x  # (B, T, N, C)
 
 
 def deblockify(x, block_size: int):
@@ -163,23 +159,30 @@ def deblockify(x, block_size: int):
         x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
         block_size (int): edge length of a single square block in units of desired H, W
     """
-    B, T, _, C= x.shape
+    B, T, _, C = x.shape
     grid_size = int(math.sqrt(T))
-    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
-    x = x.permute(0, 1, 3, 2, 4, 5)
     height = width = grid_size * block_size
-    x = x.reshape(B, height, width, C)
-    return x # (B, H, W, C)
-
+    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
+    x = x.transpose(2, 3).reshape(B, height, width, C)
+    return x  # (B, H, W, C)
+
 
 class NestLevel(nn.Module):
     """ Single hierarchical level of a Nested Transformer
     """
-    def __init__(self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, mlp_ratio=4., qkv_bias=True,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rates=[], norm_layer=None, act_layer=None):
+    def __init__(
+            self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
+            mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
+            norm_layer=None, act_layer=None, pad_type=''):
         super().__init__()
         self.block_size = block_size
         self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
+
+        if prev_embed_dim is not None:
+            self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
+        else:
+            self.pool = nn.Identity()
+
         # Transformer encoder
         if len(drop_path_rates):
             assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
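
The `transpose(2, 3)` rewrites of `blockify`/`deblockify` above preserve behavior while saving one reshape each; the two functions remain exact inverses. A round-trip sketch, assuming the module path:

```python
import torch
from timm.models.nest import blockify, deblockify  # assumed import path

x = torch.randn(2, 56, 56, 128)    # (B, H, W, C), channels-last
b = blockify(x, block_size=14)     # (2, 16, 196, 128): T=16 blocks of N=196 tokens
y = deblockify(b, block_size=14)   # back to (2, 56, 56, 128)
assert torch.equal(x, y)
```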
@@ -194,15 +197,14 @@ def forward(self, x):
         """
         expects x as (B, C, H, W)
         """
-        # Switch to channels last for transformer
-        x = x.permute(0, 2, 3, 1) # (B, H', W', C)
-        x = blockify(x, self.block_size) # (B, T, N, C')
+        x = self.pool(x)
+        x = x.permute(0, 2, 3, 1)  # (B, H', W', C), switch to channels last for transformer
+        x = blockify(x, self.block_size)  # (B, T, N, C')
         x = x + self.pos_embed
-        x = self.transformer_encoder(x) # (B, T, N, C')
-        x = deblockify(x, self.block_size) # (B, H', W', C')
+        x = self.transformer_encoder(x)  # (B, T, N, C')
+        x = deblockify(x, self.block_size)  # (B, H', W', C')
         # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
-        x = x.permute(0, 3, 1, 2) # (B, C, H', W')
-        return x
+        return x.permute(0, 3, 1, 2)  # (B, C, H', W')
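
With pooling folded in, each level is now a self-contained (B, C, H, W)-in, (B, C', H/2, W/2)-out stage when a ConvPool is present. A hedged instantiation sketch (dims loosely follow the small/tiny configs; assumes the definitions above and the default drop path handling):

```python
import torch
from functools import partial
from torch import nn
from timm.models.nest import NestLevel  # assumed import path

level = NestLevel(
    num_blocks=4, block_size=14, seq_length=196, num_heads=6, depth=2,
    embed_dim=192, prev_embed_dim=96,
    norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
x = torch.zeros(2, 96, 56, 56)              # output of a hypothetical previous level
assert level(x).shape == (2, 192, 28, 28)   # ConvPool halves H, W before the transformer
```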
 
 
 class Nest(nn.Module):
@@ -213,9 +215,9 @@ class Nest(nn.Module):
     """
 
     def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
-                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, pad_type='',
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, weight_init='',
-                 global_pool='avg'):
+                 num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
+                 pad_type='', weight_init='', global_pool='avg'):
         """
         Args:
             img_size (int, tuple): input image size
@@ -233,6 +235,7 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
             drop_path_rate (float): stochastic depth rate
             norm_layer: (nn.Module): normalization layer for transformer layers
             act_layer: (nn.Module): activation layer in MLP of transformer layers
+            pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
             weight_init: (str): weight init scheme
             global_pool: (str): type of pooling operation to apply to final feature map
 
@@ -254,6 +257,7 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
         depths = to_ntuple(num_levels)(depths)
         self.num_classes = num_classes
         self.num_features = embed_dims[-1]
+        self.feature_info = []
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
         self.drop_rate = drop_rate
@@ -265,60 +269,54 @@ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_d
         self.patch_size = patch_size
 
         # Number of blocks at each level
-        self.num_blocks = 4**(np.arange(num_levels)[::-1])
-        assert (img_size // patch_size) % np.sqrt(self.num_blocks[0]) == 0, \
-            'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
+        self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist()
+        assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
+            'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
 
         # Block edge size in units of patches
         # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
         # number of blocks along edge of image
-        self.block_size = int((img_size // patch_size) // np.sqrt(self.num_blocks[0]))
+        self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
 
         # Patch embedding
         self.patch_embed = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
         self.num_patches = self.patch_embed.num_patches
         self.seq_length = self.num_patches // self.num_blocks[0]
 
         # Build up each hierarchical level
-        self.levels = nn.ModuleList([])
-        self.block_aggs = nn.ModuleList([])
-        drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
-        for lix in range(self.num_levels):
-            dpr = drop_path_rates[sum(depths[:lix]):sum(depths[:lix+1])]
-            self.levels.append(NestLevel(
-                self.num_blocks[lix], self.block_size, self.seq_length, num_heads[lix], depths[lix],
-                embed_dims[lix], mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dpr, norm_layer,
-                act_layer))
-            if lix < self.num_levels - 1:
-                self.block_aggs.append(BlockAggregation(
-                    embed_dims[lix], embed_dims[lix+1], norm_layer, pad_type=pad_type))
-            else:
-                # Required for zipped iteration over levels and ls_block_agg together
-                self.block_aggs.append(nn.Identity())
+        levels = []
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+        prev_dim = None
+        curr_stride = 4
+        for i in range(len(self.num_blocks)):
+            dim = embed_dims[i]
+            levels.append(NestLevel(
+                self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
+                mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
+            self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
+            prev_dim = dim
+            curr_stride *= 2
+        self.levels = nn.Sequential(*levels)
 
         # Final normalization layer
         self.norm = norm_layer(embed_dims[-1])
 
         # Classifier
-        self.global_pool, self.head = create_classifier(
-            self.num_features, self.num_classes, pool_type=global_pool)
+        self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         self.init_weights(weight_init)
 
     def init_weights(self, mode=''):
-        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+        assert mode in ('nlhb', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         for level in self.levels:
             trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
-        if mode.startswith('jax'):
-            named_apply(partial(_init_nest_weights, head_bias=head_bias, jax_impl=True), self)
-        else:
-            named_apply(_init_nest_weights, self)
+        named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'pos_embed'}
+        return {f'level.{i}.pos_embed' for i in range(len(self.levels))}
 
     def get_classifier(self):
         return self.head
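
For the default config above (`patch_size=4`, `embed_dims=(128, 256, 512)`), the loop builds `feature_info` as follows (hedged arithmetic: `curr_stride` starts at 4 and doubles per level):

```python
expected_feature_info = [
    dict(num_chs=128, reduction=4, module='levels.0'),
    dict(num_chs=256, reduction=8, module='levels.1'),
    dict(num_chs=512, reduction=16, module='levels.2'),
]
```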
@@ -333,13 +331,8 @@ def forward_features(self, x):
         """
         B, _, H, W = x.shape
         x = self.patch_embed(x)
-        x = x.reshape(B, H//self.patch_size, W//self.patch_size, -1) # (B, H', W', C')
-        x = x.permute(0, 3, 1, 2)
-        # NOTE: TorchScript won't let us subscript module lists with integer variables, so we iterate instead
-        for level, block_agg in zip(self.levels, self.block_aggs):
-            x = level(x)
-            x = block_agg(x)
-        # Layer norm done over channel dim only
+        x = self.levels(x)
+        # Layer norm done over channel dim only (to NHWC and back)
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
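
The final norm applies LayerNorm over the channel dim only by round-tripping through NHWC; a standalone equivalent of that one-liner:

```python
import torch
from torch import nn

x = torch.randn(1, 512, 7, 7)                        # (B, C, H, W)
norm = nn.LayerNorm(512)
y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # normalize over C at each position
assert y.shape == x.shape
```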
 
@@ -353,22 +346,19 @@ def forward(self, x):
         return self.head(x)
 
 
-def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
+def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
     """ NesT weight initialization
     Can replicate Jax implementation. Otherwise follows vision_transformer.py
     """
     if isinstance(module, nn.Linear):
         if name.startswith('head'):
-            if jax_impl:
-                trunc_normal_(module.weight, std=.02, a=-2, b=2)
-            else:
-                nn.init.zeros_(module.weight)
+            trunc_normal_(module.weight, std=.02, a=-2, b=2)
             nn.init.constant_(module.bias, head_bias)
         else:
             trunc_normal_(module.weight, std=.02, a=-2, b=2)
             if module.bias is not None:
-                nn.init.zeros_(module.bias)
-    elif jax_impl and isinstance(module, nn.Conv2d):
+                nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
         trunc_normal_(module.weight, std=.02, a=-2, b=2)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
@@ -404,13 +394,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
     default_cfg = default_cfg or default_cfgs[variant]
     model = build_model_with_cfg(
         Nest, variant, pretrained,
         default_cfg=default_cfg,
+        feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
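
`out_indices` in the `feature_cfg` above selects which levels a features_only model returns; a hedged sketch requesting only the last two:

```python
import timm
import torch

backbone = timm.create_model('nest_tiny', features_only=True, out_indices=(1, 2))
feats = backbone(torch.zeros(1, 3, 224, 224))
print([f.shape for f in feats])  # expect (1, 192, 28, 28) and (1, 384, 14, 14)
```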
@@ -422,7 +410,7 @@ def nest_base(pretrained=False, **kwargs):
     """ Nest-B @ 224x224
     """
     model_kwargs = dict(
-        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
+        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
     model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -431,8 +419,7 @@ def nest_base(pretrained=False, **kwargs):
 def nest_small(pretrained=False, **kwargs):
     """ Nest-S @ 224x224
     """
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
     model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -441,8 +428,7 @@ def nest_small(pretrained=False, **kwargs):
 def nest_tiny(pretrained=False, **kwargs):
     """ Nest-T @ 224x224
     """
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
     model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -452,9 +438,7 @@ def jx_nest_base(pretrained=False, **kwargs):
     """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
     """
     kwargs['pad_type'] = 'same'
-    kwargs['weight_init'] = 'jax'
-    model_kwargs = dict(
-        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), drop_path_rate=0.5, **kwargs)
+    model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
     model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -464,9 +448,7 @@ def jx_nest_small(pretrained=False, **kwargs):
     """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
     """
     kwargs['pad_type'] = 'same'
-    kwargs['weight_init'] = 'jax'
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), drop_path_rate=0.3, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
     model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -476,8 +458,6 @@ def jx_nest_tiny(pretrained=False, **kwargs):
     """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
     """
     kwargs['pad_type'] = 'same'
-    kwargs['weight_init'] = 'jax'
-    model_kwargs = dict(
-        embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), drop_path_rate=0.2, **kwargs)
+    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
     model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
-    return model
+    return model
