""" PP-HGNet (V1 & V2)

Reference:
https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ImageNet1k/PP-HGNetV2.md
The Paddle implementation of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/docs/en/models/PP-HGNet_en.md)
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['HighPerfGpuNet']


class LearnableAffineBlock(nn.Module):
    def __init__(
            self,
            scale_value=1.0,
            bias_value=0.0
    ):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
        self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)

    def forward(self, x):
        return self.scale * x + self.bias


class ConvBNAct(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            stride=1,
            groups=1,
            padding='',
            use_act=True,
            use_lab=False
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab
        self.conv = create_conv2d(
            in_chs,
            out_chs,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_chs)
        if self.use_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()
        if self.use_act and self.use_lab:
            self.lab = LearnableAffineBlock()
        else:
            self.lab = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.lab(x)
        return x


class LightConvBNAct(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            groups=1,
            use_lab=False
    ):
        super().__init__()
        self.conv1 = ConvBNAct(
            in_chs,
            out_chs,
            kernel_size=1,
            use_act=False,
            use_lab=use_lab,
        )
        self.conv2 = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class EseModule(nn.Module):
    # Effective Squeeze-and-Excitation: channel attention via a single 1x1 conv
    # over globally pooled features, i.e. SE without the channel reduction step.
    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(
            chs,
            chs,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = x.mean((2, 3), keepdim=True)
        x = self.conv(x)
        x = self.sigmoid(x)
        return torch.mul(identity, x)


class StemV1(nn.Module):
    # for PP-HGNet
    def __init__(self, stem_chs):
        super().__init__()
        self.stem = nn.Sequential(*[
            ConvBNAct(
                stem_chs[i],
                stem_chs[i + 1],
                kernel_size=3,
                stride=2 if i == 0 else 1) for i in range(
                len(stem_chs) - 1)
        ])
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.stem(x)
        x = self.pool(x)
        return x


class StemV2(nn.Module):
    # for PP-HGNetv2
    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()
        self.stem1 = ConvBNAct(
            in_chs,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem2a = ConvBNAct(
            mid_chs,
            mid_chs // 2,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem2b = ConvBNAct(
            mid_chs // 2,
            mid_chs,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem3 = ConvBNAct(
            mid_chs * 2,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem4 = ConvBNAct(
            mid_chs,
            out_chs,
            kernel_size=1,
            stride=1,
            use_lab=use_lab,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        x = self.stem1(x)
        # pad right/bottom by 1 so the 2x2 convs and the 2x2 max pool below keep the
        # unpadded spatial size; the two branches then match for the channel concat
        x = F.pad(x, (0, 1, 0, 1))
        x2 = self.stem2a(x)
        x2 = F.pad(x2, (0, 1, 0, 1))
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        x = torch.cat([x1, x2], dim=1)
        x = self.stem3(x)
        x = self.stem4(x)
        return x


class HighPerfGpuBlock(nn.Module):
    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            layer_num,
            kernel_size=3,
            residual=False,
            light_block=False,
            use_lab=False,
            agg='ese',
            drop_path=0.,
    ):
        super().__init__()
        self.residual = residual

        self.layers = nn.ModuleList()
        for i in range(layer_num):
            if light_block:
                self.layers.append(
                    LightConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        use_lab=use_lab,
                    )
                )
            else:
                self.layers.append(
                    ConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        stride=1,
                        use_lab=use_lab,
                    )
                )

        # feature aggregation
        total_chs = in_chs + layer_num * mid_chs
        if agg == 'se':
            aggregation_squeeze_conv = ConvBNAct(
                total_chs,
                out_chs // 2,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            aggregation_excitation_conv = ConvBNAct(
                out_chs // 2,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            self.aggregation = nn.Sequential(
                aggregation_squeeze_conv,
                aggregation_excitation_conv,
            )
        else:
            aggregation_conv = ConvBNAct(
                total_chs,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            att = EseModule(out_chs)
            self.aggregation = nn.Sequential(
                aggregation_conv,
                att,
            )

        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        identity = x
        output = [x]
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        x = self.aggregation(x)
        if self.residual:
            x = self.drop_path(x) + identity
        return x


class HighPerfGpuStage(nn.Module):
    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            block_num,
            layer_num,
            downsample=True,
            stride=2,
            light_block=False,
            kernel_size=3,
            use_lab=False,
            agg='ese',
            drop_path=0.,
    ):
        super().__init__()
        if downsample:
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=stride,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        blocks_list = []
        for i in range(block_num):
            blocks_list.append(
                HighPerfGpuBlock(
                    in_chs if i == 0 else out_chs,
                    mid_chs,
                    out_chs,
                    layer_num,
                    residual=False if i == 0 else True,
                    kernel_size=kernel_size,
                    light_block=light_block,
                    use_lab=use_lab,
                    agg=agg,
                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                )
            )
        self.blocks = nn.Sequential(*blocks_list)
        self.grad_checkpointing = False

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # checkpoint each block so set_grad_checkpointing() actually takes effect
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class ClassifierHead(nn.Module):
    def __init__(
            self,
            num_features,
            num_classes,
            pool_type='avg',
            drop_rate=0.,
            use_last_conv=True,
            class_expand=2048,
            use_lab=False
    ):
        super(ClassifierHead, self).__init__()
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=False, input_fmt='NCHW')
        if use_last_conv:
            last_conv = nn.Conv2d(
                num_features,
                class_expand,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            act = nn.ReLU()
            if use_lab:
                lab = LearnableAffineBlock()
                self.last_conv = nn.Sequential(last_conv, act, lab)
            else:
                self.last_conv = nn.Sequential(last_conv, act)
        else:
            self.last_conv = nn.Identity()

        if drop_rate > 0:
            self.dropout = nn.Dropout(drop_rate)
        else:
            self.dropout = nn.Identity()

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(class_expand if use_last_conv else num_features, num_classes)

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.last_conv(x)
        x = self.dropout(x)
        x = self.flatten(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x


class HighPerfGpuNet(nn.Module):

    def __init__(
            self,
            cfg,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            use_last_conv=True,
            class_expand=2048,
            drop_rate=0.,
            drop_path_rate=0.,
            use_lab=False,
            **kwargs,
    ):
        super(HighPerfGpuNet, self).__init__()
        stem_type = cfg["stem_type"]
        stem_chs = cfg["stem_chs"]
        stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]]
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.use_last_conv = use_last_conv
        self.class_expand = class_expand
        self.use_lab = use_lab

        assert stem_type in ['v1', 'v2']
        if stem_type == 'v2':
            self.stem = StemV2(
                in_chs=in_chans,
                mid_chs=stem_chs[0],
                out_chs=stem_chs[1],
                use_lab=use_lab)
        else:
            self.stem = StemV1([in_chans] + stem_chs)

        current_stride = 4

        stages = []
        self.feature_info = []
        block_depths = [c[3] for c in stages_cfg]
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_depths)).split(block_depths)]
        for i, stage_config in enumerate(stages_cfg):
            in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config
            stages += [HighPerfGpuStage(
                in_chs=in_chs,
                mid_chs=mid_chs,
                out_chs=out_chs,
                block_num=block_num,
                layer_num=layer_num,
                downsample=downsample,
                light_block=light_block,
                kernel_size=kernel_size,
                use_lab=use_lab,
                agg='ese' if stem_type == 'v1' else 'se',
                drop_path=dpr[i],
            )]
            self.num_features = out_chs
            if downsample:
                current_stride *= 2
            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                num_classes=num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                use_last_conv=use_last_conv,
                class_expand=class_expand,
                use_lab=use_lab
            )
        else:
            if global_pool == 'avg':
                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
            else:
                self.head = nn.Identity()

        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                num_classes=num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                use_last_conv=self.use_last_conv,
                class_expand=self.class_expand,
                use_lab=self.use_lab)
        else:
            if global_pool:
                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
            else:
                self.head = nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        return self.stages(x)

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


model_cfgs = dict(
    # PP-HGNet
    hgnet_tiny={
        "stem_type": 'v1',
        "stem_chs": [48, 48, 96],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [96, 96, 224, 1, False, False, 3, 5],
        "stage2": [224, 128, 448, 1, True, False, 3, 5],
        "stage3": [448, 160, 512, 2, True, False, 3, 5],
        "stage4": [512, 192, 768, 1, True, False, 3, 5],
    },
    hgnet_small={
        "stem_type": 'v1',
        "stem_chs": [64, 64, 128],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [128, 128, 256, 1, False, False, 3, 6],
        "stage2": [256, 160, 512, 1, True, False, 3, 6],
        "stage3": [512, 192, 768, 2, True, False, 3, 6],
        "stage4": [768, 224, 1024, 1, True, False, 3, 6],
    },
    hgnet_base={
        "stem_type": 'v1',
        "stem_chs": [96, 96, 160],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [160, 192, 320, 1, False, False, 3, 7],
        "stage2": [320, 224, 640, 2, True, False, 3, 7],
        "stage3": [640, 256, 960, 3, True, False, 3, 7],
        "stage4": [960, 288, 1280, 2, True, False, 3, 7],
    },
    # PP-HGNetv2
    hgnetv2_b0={
        "stem_type": 'v2',
        "stem_chs": [16, 16],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [16, 16, 64, 1, False, False, 3, 3],
        "stage2": [64, 32, 256, 1, True, False, 3, 3],
        "stage3": [256, 64, 512, 2, True, True, 5, 3],
        "stage4": [512, 128, 1024, 1, True, True, 5, 3],
    },
    hgnetv2_b1={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 64, 1, False, False, 3, 3],
        "stage2": [64, 48, 256, 1, True, False, 3, 3],
        "stage3": [256, 96, 512, 2, True, True, 5, 3],
        "stage4": [512, 192, 1024, 1, True, True, 5, 3],
    },
    hgnetv2_b2={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 96, 1, False, False, 3, 4],
        "stage2": [96, 64, 384, 1, True, False, 3, 4],
        "stage3": [384, 128, 768, 3, True, True, 5, 4],
        "stage4": [768, 256, 1536, 1, True, True, 5, 4],
    },
    hgnetv2_b3={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 128, 1, False, False, 3, 5],
        "stage2": [128, 64, 512, 1, True, False, 3, 5],
        "stage3": [512, 128, 1024, 3, True, True, 5, 5],
        "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
    },
    hgnetv2_b4={
        "stem_type": 'v2',
        "stem_chs": [32, 48],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [48, 48, 128, 1, False, False, 3, 6],
        "stage2": [128, 96, 512, 1, True, False, 3, 6],
        "stage3": [512, 192, 1024, 3, True, True, 5, 6],
        "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
    },
    hgnetv2_b5={
        "stem_type": 'v2',
        "stem_chs": [32, 64],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [64, 64, 128, 1, False, False, 3, 6],
        "stage2": [128, 128, 512, 2, True, False, 3, 6],
        "stage3": [512, 256, 1024, 5, True, True, 5, 6],
        "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
    },
    hgnetv2_b6={
        "stem_type": 'v2',
        "stem_chs": [48, 96],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [96, 96, 192, 2, False, False, 3, 6],
        "stage2": [192, 192, 512, 3, True, False, 3, 6],
        "stage3": [512, 384, 1024, 6, True, True, 5, 6],
        "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
    },
)
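
# Illustration only (not used at runtime): how a config row above expands into a stage,
# following the unpacking order in HighPerfGpuNet.__init__. For example, the hgnetv2_b0
# row "stage1": [16, 16, 64, 1, False, False, 3, 3] builds roughly:
#
#   HighPerfGpuStage(
#       in_chs=16, mid_chs=16, out_chs=64, block_num=1, layer_num=3,
#       downsample=False, light_block=False, kernel_size=3,
#       use_lab=True, agg='se', drop_path=...)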


def _create_hgnet(variant, pretrained=False, **kwargs):
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
    return build_model_with_cfg(
        HighPerfGpuNet,
        variant,
        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
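
# Typical usage through the timm factory (a minimal sketch, assuming this module is
# registered under timm.models and a 224x224 input; shapes shown for hgnetv2_b0):
#
#   import torch, timm
#   model = timm.create_model('hgnetv2_b0', pretrained=False)
#   logits = model(torch.randn(1, 3, 224, 224))        # -> (1, 1000)
#
#   # multi-scale feature maps via the feature_cfg / out_indices plumbing above
#   backbone = timm.create_model('hgnetv2_b0', pretrained=False, features_only=True)
#   feats = backbone(torch.randn(1, 3, 224, 224))      # 4 maps at strides 4 / 8 / 16 / 32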


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.965, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head.fc', 'first_conv': 'stem.stem1.conv',
        'test_crop_pct': 1.0, 'test_input_size': (3, 288, 288),
        **kwargs,
    }


default_cfgs = generate_default_cfgs({
    'hgnet_tiny.paddle_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_tiny.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_small.paddle_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_small.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_base.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnetv2_b0.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b0.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b1.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b1.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b2.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b2.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b3.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b3.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b4.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b4.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b5.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b5.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b6.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b6.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
})


@register_model
def hgnet_tiny(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_tiny', pretrained=pretrained, **kwargs)


@register_model
def hgnet_small(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_small', pretrained=pretrained, **kwargs)


@register_model
def hgnet_base(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_base', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b0(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b0', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b1(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b1', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b2(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b2', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b3(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b3', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b4(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b4', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b5(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b5', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b6(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b6', pretrained=pretrained, **kwargs)
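

# Quick sanity-check sketch (hedged: run from an environment where timm is importable,
# not by executing this module directly, since it relies on package-relative imports):
#
#   import torch, timm
#   m = timm.create_model('hgnet_tiny', pretrained=False, num_classes=10)
#   assert m(torch.randn(2, 3, 224, 224)).shape == (2, 10)
#   m.reset_classifier(0)                                      # pooled features only
#   assert m(torch.randn(2, 3, 224, 224)).shape == (2, 768)    # hgnet_tiny final width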