Skip to content

Commit d70c481

Browse files
committed
override device kwargs of base nn classes
1 parent e44f14d commit d70c481

File tree

1 file changed

+44
-5
lines changed

1 file changed

+44
-5
lines changed

timm/models/_builder.py

+44-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from contextlib import contextmanager, nullcontext
12
import dataclasses
23
import logging
34
import os
45
from copy import deepcopy
56
from pathlib import Path
6-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7+
from typing import Any, Callable, Dict, Optional, Tuple, Union
78

9+
import torch
810
from torch import nn as nn
911
from torch.hub import load_state_dict_from_url
1012

@@ -360,6 +362,27 @@ def resolve_pretrained_cfg(
360362
return pretrained_cfg
361363

362364

365+
@contextmanager
366+
def make_meta_init(*classes):
367+
def create_new_init(cls):
368+
old_init = cls.__init__
369+
def new_init(self, *args, **kwargs):
370+
kwargs.update(device="meta")
371+
old_init(self, *args, **kwargs)
372+
return new_init
373+
374+
original_dict = dict()
375+
for cls in classes:
376+
original_dict[cls] = cls.__init__
377+
cls.__init__ = create_new_init(cls)
378+
379+
yield
380+
381+
# restore original __init__()
382+
for cls, old_init in original_dict.items():
383+
cls.__init__ = old_init
384+
385+
363386
def build_model_with_cfg(
364387
model_cls: Callable,
365388
variant: str,
@@ -419,11 +442,27 @@ def build_model_with_cfg(
419442
if 'feature_cls' in kwargs:
420443
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')
421444

445+
# use meta-device init to speed up loading pretrained weights.
446+
# when num_classes is changed, we can't use meta device init since we need
447+
# the original __init__() to initialize head from scratch.
448+
num_classes = 0 if features else kwargs.get("num_classes", pretrained_cfg["num_classes"])
449+
use_meta_init = (
450+
pretrained
451+
and (num_classes == 0 or num_classes == pretrained_cfg["num_classes"])
452+
)
453+
422454
# Instantiate the model
423-
if model_cfg is None:
424-
model = model_cls(**kwargs)
425-
else:
426-
model = model_cls(cfg=model_cfg, **kwargs)
455+
base_classes = [nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.LayerNorm]
456+
with make_meta_init(*base_classes) if use_meta_init else nullcontext():
457+
if model_cfg is None:
458+
model = model_cls(**kwargs)
459+
else:
460+
model = model_cls(cfg=model_cfg, **kwargs)
461+
462+
# convert meta-device tensors to concrete tensors
463+
device = kwargs.get("device", torch.get_default_device())
464+
model._apply(lambda t: (torch.empty_like(t, device=device) if t.is_meta else t))
465+
427466
model.pretrained_cfg = pretrained_cfg
428467
model.default_cfg = model.pretrained_cfg # alias for backwards compat
429468

0 commit comments

Comments
 (0)