|
| 1 | +from contextlib import contextmanager, nullcontext |
1 | 2 | import dataclasses
|
2 | 3 | import logging
|
3 | 4 | import os
|
4 | 5 | from copy import deepcopy
|
5 | 6 | from pathlib import Path
|
6 |
| -from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| 7 | +from typing import Any, Callable, Dict, Optional, Tuple, Union |
7 | 8 |
|
| 9 | +import torch |
8 | 10 | from torch import nn as nn
|
9 | 11 | from torch.hub import load_state_dict_from_url
|
10 | 12 |
|
@@ -360,6 +362,27 @@ def resolve_pretrained_cfg(
|
360 | 362 | return pretrained_cfg
|
361 | 363 |
|
362 | 364 |
|
@contextmanager
def make_meta_init(*classes):
    """Temporarily patch each class's ``__init__`` to construct on the meta device.

    Within the ``with`` block, every instantiation of the given classes is
    forced to ``device="meta"``, which skips real parameter allocation and
    random weight initialization. This speeds up model construction when the
    weights will immediately be overwritten by a pretrained checkpoint.

    Args:
        *classes: Classes whose ``__init__`` accepts a ``device`` keyword
            argument (e.g. ``nn.Linear``, ``nn.Conv2d``).

    Note:
        Any ``device`` keyword passed by the caller inside the context is
        deliberately overridden with ``"meta"``.
    """
    def create_new_init(cls):
        old_init = cls.__init__
        def new_init(self, *args, **kwargs):
            # Force meta device, overriding any caller-supplied device.
            kwargs.update(device="meta")
            old_init(self, *args, **kwargs)
        return new_init

    # Capture the originals before patching so they can be restored exactly.
    original_dict = {cls: cls.__init__ for cls in classes}
    try:
        for cls in original_dict:
            cls.__init__ = create_new_init(cls)
        yield
    finally:
        # Restore original __init__() even if the body raised; otherwise the
        # meta-device patch would leak onto globally shared classes.
        for cls, old_init in original_dict.items():
            cls.__init__ = old_init
| 385 | + |
363 | 386 | def build_model_with_cfg(
|
364 | 387 | model_cls: Callable,
|
365 | 388 | variant: str,
|
@@ -419,11 +442,27 @@ def build_model_with_cfg(
|
419 | 442 | if 'feature_cls' in kwargs:
|
420 | 443 | feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
|
421 | 444 |
|
| 445 | + # use meta-device init to speed up loading pretrained weights. |
| 446 | + # when num_classes is changed, we can't use meta device init since we need |
| 447 | + # the original __init__() to initialize head from scratch. |
| 448 | + num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"]) |
| 449 | + use_meta_init = ( |
| 450 | + pretrained |
| 451 | + and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"]) |
| 452 | + ) |
| 453 | + |
422 | 454 | # Instantiate the model
|
423 |
| - if model_cfg is None: |
424 |
| - model = model_cls(**kwargs) |
425 |
| - else: |
426 |
| - model = model_cls(cfg=model_cfg, **kwargs) |
| 455 | + base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm] |
| 456 | + with make_meta_init(*base_classes) if use_meta_init else nullcontext(): |
| 457 | + if model_cfg is None: |
| 458 | + model = model_cls(**kwargs) |
| 459 | + else: |
| 460 | + model = model_cls(cfg=model_cfg, **kwargs) |
| 461 | + |
| 462 | + # convert meta-device tensors to concrete tensors |
| 463 | + device = kwargs.get("device", torch.get_default_device()) |
| 464 | + model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t)) |
| 465 | + |
427 | 466 | model.pretrained_cfg = pretrained_cfg
|
428 | 467 | model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
429 | 468 |
|
|
0 commit comments