Commit 8a713b0

committed Feb 12, 2024
Merge branch 'seefun-master'
2 parents 9589388 + 31e0dc0 commit 8a713b0

File tree

2 files changed: +745 -0 lines changed


timm/models/__init__.py (+1)
@@ -25,6 +25,7 @@
 from .gcvit import *
 from .ghostnet import *
 from .hardcorenas import *
+from .hgnet import *
 from .hrnet import *
 from .inception_next import *
 from .inception_resnet_v2 import *

timm/models/hgnet.py (+744)

@@ -0,0 +1,744 @@
""" PP-HGNet (V1 & V2)

Reference:
https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ImageNet1k/PP-HGNetV2.md
The Paddle implementation of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/docs/en/models/PP-HGNet_en.md)
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['HighPerfGpuNet']


class LearnableAffineBlock(nn.Module):
    def __init__(
            self,
            scale_value=1.0,
            bias_value=0.0
    ):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
        self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)

    def forward(self, x):
        return self.scale * x + self.bias
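
# Hypothetical sanity check (a sketch, not part of the original file):
# LearnableAffineBlock applies a single learnable elementwise affine,
# y = scale * x + bias. Call by hand to verify.
def _demo_learnable_affine_block():
    lab = LearnableAffineBlock(scale_value=2.0, bias_value=0.5)
    x = torch.ones(1, 3, 4, 4)
    assert torch.allclose(lab(x), torch.full_like(x, 2.5))  # 2.0 * 1.0 + 0.5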


class ConvBNAct(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            stride=1,
            groups=1,
            padding='',
            use_act=True,
            use_lab=False
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab
        self.conv = create_conv2d(
            in_chs,
            out_chs,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_chs)
        if self.use_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()
        if self.use_act and self.use_lab:
            self.lab = LearnableAffineBlock()
        else:
            self.lab = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.lab(x)
        return x
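
# Hypothetical sketch (not part of the original file): with padding='',
# create_conv2d resolves to symmetric 'same'-style padding for odd kernels,
# so a stride-1 ConvBNAct preserves spatial size. Call by hand to verify.
def _demo_conv_bn_act():
    m = ConvBNAct(8, 16, kernel_size=3, stride=1)
    assert m(torch.randn(1, 8, 14, 14)).shape == (1, 16, 14, 14)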


class LightConvBNAct(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            groups=1,
            use_lab=False
    ):
        super().__init__()
        self.conv1 = ConvBNAct(
            in_chs,
            out_chs,
            kernel_size=1,
            use_act=False,
            use_lab=use_lab,
        )
        self.conv2 = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class EseModule(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(
            chs,
            chs,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = x.mean((2, 3), keepdim=True)
        x = self.conv(x)
        x = self.sigmoid(x)
        return torch.mul(identity, x)
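
# Hypothetical sketch (not part of the original file): EseModule is an
# "effective SE" gate — global average pool, one 1x1 conv, sigmoid — whose
# per-channel gate broadcasts over H and W. Call by hand to verify.
def _demo_ese_module():
    ese = EseModule(chs=64)
    x = torch.randn(2, 64, 7, 7)
    assert ese(x).shape == x.shape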


class StemV1(nn.Module):
    # for PP-HGNet
    def __init__(self, stem_chs):
        super().__init__()
        self.stem = nn.Sequential(*[
            ConvBNAct(
                stem_chs[i],
                stem_chs[i + 1],
                kernel_size=3,
                stride=2 if i == 0 else 1) for i in range(
                len(stem_chs) - 1)
        ])
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.stem(x)
        x = self.pool(x)
        return x


class StemV2(nn.Module):
    # for PP-HGNetv2
    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()
        self.stem1 = ConvBNAct(
            in_chs,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem2a = ConvBNAct(
            mid_chs,
            mid_chs // 2,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem2b = ConvBNAct(
            mid_chs // 2,
            mid_chs,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem3 = ConvBNAct(
            mid_chs * 2,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem4 = ConvBNAct(
            mid_chs,
            out_chs,
            kernel_size=1,
            stride=1,
            use_lab=use_lab,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        x = self.stem1(x)
        # pad right/bottom by 1 so the even (2x2) kernels and the stride-1
        # ceil-mode pool both preserve spatial size before the concat
        x = F.pad(x, (0, 1, 0, 1))
        x2 = self.stem2a(x)
        x2 = F.pad(x2, (0, 1, 0, 1))
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        x = torch.cat([x1, x2], dim=1)
        x = self.stem3(x)
        x = self.stem4(x)
        return x
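
# Hypothetical shape trace for StemV2 (a sketch, not part of the original
# file), using hgnetv2_b0-style stem widths. Call by hand to verify.
def _demo_stem_v2():
    stem = StemV2(in_chs=3, mid_chs=16, out_chs=16)
    out = stem(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 16, 56, 56)  # stem1 and stem3 each halve H/W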


class HighPerfGpuBlock(nn.Module):
    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            layer_num,
            kernel_size=3,
            residual=False,
            light_block=False,
            use_lab=False,
            agg='ese',
            drop_path=0.,
    ):
        super().__init__()
        self.residual = residual

        self.layers = nn.ModuleList()
        for i in range(layer_num):
            if light_block:
                self.layers.append(
                    LightConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        use_lab=use_lab,
                    )
                )
            else:
                self.layers.append(
                    ConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        stride=1,
                        use_lab=use_lab,
                    )
                )

        # feature aggregation
        total_chs = in_chs + layer_num * mid_chs
        if agg == 'se':
            aggregation_squeeze_conv = ConvBNAct(
                total_chs,
                out_chs // 2,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            aggregation_excitation_conv = ConvBNAct(
                out_chs // 2,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            self.aggregation = nn.Sequential(
                aggregation_squeeze_conv,
                aggregation_excitation_conv,
            )
        else:
            aggregation_conv = ConvBNAct(
                total_chs,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            att = EseModule(out_chs)
            self.aggregation = nn.Sequential(
                aggregation_conv,
                att,
            )

        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        identity = x
        output = [x]
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        x = self.aggregation(x)
        if self.residual:
            x = self.drop_path(x) + identity
        return x
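
# Hypothetical channel-arithmetic check (a sketch, not part of the original
# file): the block concatenates its input with every intermediate output, so
# the aggregation 1x1 conv sees in_chs + layer_num * mid_chs channels. Call
# by hand to verify.
def _demo_hg_block():
    blk = HighPerfGpuBlock(in_chs=96, mid_chs=96, out_chs=224, layer_num=5)
    out = blk(torch.randn(1, 96, 56, 56))
    assert out.shape == (1, 224, 56, 56)  # concat saw 96 + 5 * 96 = 576 chs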


class HighPerfGpuStage(nn.Module):
    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            block_num,
            layer_num,
            downsample=True,
            stride=2,
            light_block=False,
            kernel_size=3,
            use_lab=False,
            agg='ese',
            drop_path=0.,
    ):
        super().__init__()
        if downsample:
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=stride,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        blocks_list = []
        for i in range(block_num):
            blocks_list.append(
                HighPerfGpuBlock(
                    in_chs if i == 0 else out_chs,
                    mid_chs,
                    out_chs,
                    layer_num,
                    residual=False if i == 0 else True,
                    kernel_size=kernel_size,
                    light_block=light_block,
                    use_lab=use_lab,
                    agg=agg,
                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                )
            )
        self.blocks = nn.Sequential(*blocks_list)
        self.grad_checkpointing = False  # toggled via set_grad_checkpointing

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # honor the flag set by HighPerfGpuNet.set_grad_checkpointing
            x = checkpoint_seq(self.blocks, x, flatten=False)
        else:
            x = self.blocks(x)
        return x


class ClassifierHead(nn.Module):
    def __init__(
            self,
            num_features,
            num_classes,
            pool_type='avg',
            drop_rate=0.,
            use_last_conv=True,
            class_expand=2048,
            use_lab=False
    ):
        super(ClassifierHead, self).__init__()
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=False, input_fmt='NCHW')
        if use_last_conv:
            last_conv = nn.Conv2d(
                num_features,
                class_expand,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            act = nn.ReLU()
            if use_lab:
                lab = LearnableAffineBlock()
                self.last_conv = nn.Sequential(last_conv, act, lab)
            else:
                self.last_conv = nn.Sequential(last_conv, act)
        else:
            self.last_conv = nn.Identity()

        if drop_rate > 0:
            self.dropout = nn.Dropout(drop_rate)
        else:
            self.dropout = nn.Identity()

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(class_expand if use_last_conv else num_features, num_classes)

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.last_conv(x)
        x = self.dropout(x)
        x = self.flatten(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
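
# Hypothetical sketch (not part of the original file): pre_logits=True
# short-circuits before the final fc, returning the pooled, expanded, and
# flattened features. Call by hand to verify.
def _demo_classifier_head():
    head = ClassifierHead(num_features=768, num_classes=1000, class_expand=2048)
    feats = head(torch.randn(2, 768, 7, 7), pre_logits=True)
    assert feats.shape == (2, 2048)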


class HighPerfGpuNet(nn.Module):

    def __init__(
            self,
            cfg,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            use_last_conv=True,
            class_expand=2048,
            drop_rate=0.,
            drop_path_rate=0.,
            use_lab=False,
            **kwargs,
    ):
        super(HighPerfGpuNet, self).__init__()
        stem_type = cfg["stem_type"]
        stem_chs = cfg["stem_chs"]
        stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]]
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.use_last_conv = use_last_conv
        self.class_expand = class_expand
        self.use_lab = use_lab

        assert stem_type in ['v1', 'v2']
        if stem_type == 'v2':
            self.stem = StemV2(
                in_chs=in_chans,
                mid_chs=stem_chs[0],
                out_chs=stem_chs[1],
                use_lab=use_lab)
        else:
            self.stem = StemV1([in_chans] + stem_chs)

        current_stride = 4

        stages = []
        self.feature_info = []
        block_depths = [c[3] for c in stages_cfg]
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_depths)).split(block_depths)]
        for i, stage_config in enumerate(stages_cfg):
            in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config
            stages += [HighPerfGpuStage(
                in_chs=in_chs,
                mid_chs=mid_chs,
                out_chs=out_chs,
                block_num=block_num,
                layer_num=layer_num,
                downsample=downsample,
                light_block=light_block,
                kernel_size=kernel_size,
                use_lab=use_lab,
                agg='ese' if stem_type == 'v1' else 'se',
                drop_path=dpr[i],
            )]
            self.num_features = out_chs
            if downsample:
                current_stride *= 2
            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                num_classes=num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                use_last_conv=use_last_conv,
                class_expand=class_expand,
                use_lab=use_lab
            )
        else:
            if global_pool == 'avg':
                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
            else:
                self.head = nn.Identity()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                num_classes=num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                use_last_conv=self.use_last_conv,
                class_expand=self.class_expand,
                use_lab=self.use_lab)
        else:
            if global_pool:
                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
            else:
                self.head = nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        return self.stages(x)

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


model_cfgs = dict(
    # PP-HGNet
    hgnet_tiny={
        "stem_type": 'v1',
        "stem_chs": [48, 48, 96],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [96, 96, 224, 1, False, False, 3, 5],
        "stage2": [224, 128, 448, 1, True, False, 3, 5],
        "stage3": [448, 160, 512, 2, True, False, 3, 5],
        "stage4": [512, 192, 768, 1, True, False, 3, 5],
    },
    hgnet_small={
        "stem_type": 'v1',
        "stem_chs": [64, 64, 128],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [128, 128, 256, 1, False, False, 3, 6],
        "stage2": [256, 160, 512, 1, True, False, 3, 6],
        "stage3": [512, 192, 768, 2, True, False, 3, 6],
        "stage4": [768, 224, 1024, 1, True, False, 3, 6],
    },
    hgnet_base={
        "stem_type": 'v1',
        "stem_chs": [96, 96, 160],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [160, 192, 320, 1, False, False, 3, 7],
        "stage2": [320, 224, 640, 2, True, False, 3, 7],
        "stage3": [640, 256, 960, 3, True, False, 3, 7],
        "stage4": [960, 288, 1280, 2, True, False, 3, 7],
    },
    # PP-HGNetv2
    hgnetv2_b0={
        "stem_type": 'v2',
        "stem_chs": [16, 16],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [16, 16, 64, 1, False, False, 3, 3],
        "stage2": [64, 32, 256, 1, True, False, 3, 3],
        "stage3": [256, 64, 512, 2, True, True, 5, 3],
        "stage4": [512, 128, 1024, 1, True, True, 5, 3],
    },
    hgnetv2_b1={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 64, 1, False, False, 3, 3],
        "stage2": [64, 48, 256, 1, True, False, 3, 3],
        "stage3": [256, 96, 512, 2, True, True, 5, 3],
        "stage4": [512, 192, 1024, 1, True, True, 5, 3],
    },
    hgnetv2_b2={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 96, 1, False, False, 3, 4],
        "stage2": [96, 64, 384, 1, True, False, 3, 4],
        "stage3": [384, 128, 768, 3, True, True, 5, 4],
        "stage4": [768, 256, 1536, 1, True, True, 5, 4],
    },
    hgnetv2_b3={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 128, 1, False, False, 3, 5],
        "stage2": [128, 64, 512, 1, True, False, 3, 5],
        "stage3": [512, 128, 1024, 3, True, True, 5, 5],
        "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
    },
    hgnetv2_b4={
        "stem_type": 'v2',
        "stem_chs": [32, 48],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [48, 48, 128, 1, False, False, 3, 6],
        "stage2": [128, 96, 512, 1, True, False, 3, 6],
        "stage3": [512, 192, 1024, 3, True, True, 5, 6],
        "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
    },
    hgnetv2_b5={
        "stem_type": 'v2',
        "stem_chs": [32, 64],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [64, 64, 128, 1, False, False, 3, 6],
        "stage2": [128, 128, 512, 2, True, False, 3, 6],
        "stage3": [512, 256, 1024, 5, True, True, 5, 6],
        "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
    },
    hgnetv2_b6={
        "stem_type": 'v2',
        "stem_chs": [48, 96],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [96, 96, 192, 2, False, False, 3, 6],
        "stage2": [192, 192, 512, 3, True, False, 3, 6],
        "stage3": [512, 384, 1024, 6, True, True, 5, 6],
        "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
    },
)


def _create_hgnet(variant, pretrained=False, **kwargs):
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
    return build_model_with_cfg(
        HighPerfGpuNet,
        variant,
        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
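
# Hypothetical sketch (not part of the original file): the feature_cfg above
# enables features_only extraction via timm, with out_indices selecting
# entries from HighPerfGpuNet.feature_info. Call by hand to verify.
def _demo_hgnet_features():
    import timm
    feat_model = timm.create_model('hgnetv2_b0', features_only=True, out_indices=(1, 2, 3))
    feats = feat_model(torch.randn(1, 3, 224, 224))
    assert len(feats) == 3  # one feature map per selected stage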


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.965, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head.fc', 'first_conv': 'stem.stem1.conv',
        'test_crop_pct': 1.0, 'test_input_size': (3, 288, 288),
        **kwargs,
    }


default_cfgs = generate_default_cfgs({
    'hgnet_tiny.paddle_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_tiny.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_small.paddle_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_small.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_base.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnetv2_b0.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b0.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b1.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b1.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b2.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b2.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b3.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b3.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b4.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b4.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b5.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b5.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b6.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b6.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
})


@register_model
def hgnet_tiny(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_tiny', pretrained=pretrained, **kwargs)


@register_model
def hgnet_small(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_small', pretrained=pretrained, **kwargs)


@register_model
def hgnet_base(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_base', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b0(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b0', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b1(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b1', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b2(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b2', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b3(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b3', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b4(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b4', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b5(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b5', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b6(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b6', pretrained=pretrained, **kwargs)
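
# Hypothetical end-to-end usage (a sketch, not part of the original file).
# Assumes this module is importable as part of timm so the @register_model
# entrypoints above are registered. Call by hand to verify.
def _demo_hgnet_usage():
    import timm
    model = timm.create_model('hgnet_tiny', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    assert model(x).shape == (1, 1000)
    model.reset_classifier(10)  # swap in a 10-class head
    assert model(x).shape == (1, 10)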
