Commit d199f66

Merge pull request #1467 from rwightman/clip_laion2b
Add support for fine-tuning CLIP LAION-2B image tower weights for B/32, L/14, H/14, and g/14.
2 parents a520da9 + 33e30f8 commit d199f66
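For context, a minimal sketch of how one of these image towers would be loaded once this is merged. The model name below is illustrative, not confirmed by this commit; the exact identifiers are defined in timm's model registry.

import timm

# Hypothetical model name for a LAION-2B CLIP image tower; check the
# vision_transformer.py registry for the names this PR actually adds.
model = timm.create_model('vit_base_patch32_224_clip_laion2b', pretrained=True)
model.eval()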

File tree

9 files changed: +233 -33 lines


requirements.txt

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
-torch>=1.4.0
-torchvision>=0.5.0
+torch>=1.7
+torchvision
 pyyaml
+huggingface_hub

setup.py

Lines changed: 6 additions & 3 deletions
@@ -25,13 +25,15 @@
         # 3 - Alpha
         # 4 - Beta
         # 5 - Production/Stable
-        'Development Status :: 3 - Alpha',
+        'Development Status :: 4 - Beta',
         'Intended Audience :: Education',
         'Intended Audience :: Science/Research',
         'License :: OSI Approved :: Apache Software License',
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development',
@@ -40,9 +42,10 @@
     ],

     # Note that this is a string of words separated by whitespace, not a list.
-    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
+    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
     packages=find_packages(exclude=['convert', 'tests', 'results']),
     include_package_data=True,
-    install_requires=['torch >= 1.4', 'torchvision'],
+    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
     python_requires='>=3.6',
 )
+

timm/data/constants.py

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@
 IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
 IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
 IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
+OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
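These are the RGB normalization statistics used by OpenAI CLIP. A quick sketch of how they might be used in a torchvision preprocessing pipeline; the resize/crop choices here are assumptions, not necessarily the exact transform timm builds:

from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision import transforms

# Illustrative CLIP-style eval transform using the new constants.
transform = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
])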

timm/models/helpers.py

Lines changed: 7 additions & 1 deletion
@@ -138,6 +138,9 @@ def _resolve_pretrained_source(pretrained_cfg):
         # hf-hub available as alternate weight source in default_cfg
         load_from = 'hf-hub'
         pretrained_loc = hf_hub_id
+    if load_from == 'hf-hub' and 'hf_hub_filename' in pretrained_cfg:
+        # if a filename override is set, return tuple for location w/ (hub_id, filename)
+        pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
     return load_from, pretrained_loc


@@ -246,7 +249,10 @@ def load_pretrained(
             pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
-        state_dict = load_state_dict_from_hf(pretrained_loc)
+        if isinstance(pretrained_loc, (list, tuple)):
+            state_dict = load_state_dict_from_hf(*pretrained_loc)
+        else:
+            state_dict = load_state_dict_from_hf(pretrained_loc)
     else:
         _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
         return
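The effect of this change is that a pretrained_cfg can now point at a specific file inside a Hugging Face Hub repo rather than the default pytorch_model.bin. A sketch of the resolution behavior; the hub id and filename below are made-up placeholders:

# Hypothetical cfg entries; only the hf_hub_filename handling is the point here.
pretrained_cfg = dict(
    hf_hub_id='some-org/some-clip-model',
    hf_hub_filename='open_clip_pytorch_model.bin',
)

# With the filename override present, _resolve_pretrained_source() returns
#   ('hf-hub', ('some-org/some-clip-model', 'open_clip_pytorch_model.bin'))
# and load_pretrained() unpacks the tuple into load_state_dict_from_hf(*loc).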

timm/models/hub.py

Lines changed: 5 additions & 4 deletions
@@ -13,6 +13,7 @@
 from torch.hub import _get_torch_home as get_dir

 from timm import __version__
+
 try:
     from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
     hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
@@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False):

 def has_hf_hub(necessary=False):
     if not _has_hf_hub and necessary:
-        # if no HF Hub module installed and it is necessary to continue, raise error
+        # if no HF Hub module installed, and it is necessary to continue, raise error
         raise RuntimeError(
             'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
     return _has_hf_hub
@@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):

 def _download_from_hf(model_id: str, filename: str):
     hf_model_id, hf_revision = hf_split(model_id)
-    return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
+    return hf_hub_download(hf_model_id, filename, revision=hf_revision)


 def load_model_config_from_hf(model_id: str):
@@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name


-def load_state_dict_from_hf(model_id: str):
+def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
     assert has_hf_hub(True)
-    cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
+    cached_file = _download_from_hf(model_id, filename)
     state_dict = torch.load(cached_file, map_location='cpu')
     return state_dict
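Callers that pass only a model id keep the old behavior (pytorch_model.bin); the new argument lets them fetch an alternate weight file from the same repo. A usage sketch with placeholder names:

from timm.models.hub import load_state_dict_from_hf

# Default: downloads 'pytorch_model.bin' from the repo.
# state_dict = load_state_dict_from_hf('org/model')

# Filename override; the repo id and filename here are placeholders.
state_dict = load_state_dict_from_hf('org/model', 'open_clip_pytorch_model.bin')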

timm/models/layers/patch_embed.py

Lines changed: 11 additions & 2 deletions
@@ -15,7 +15,16 @@
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            embed_dim=768,
+            norm_layer=None,
+            flatten=True,
+            bias=True,
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -25,7 +34,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten

-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

     def forward(self, x):
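The new bias flag matters for the CLIP towers: OpenAI-style CLIP ViTs apply a LayerNorm right after the patch projection, so the conv bias is redundant and omitted. A minimal sketch of constructing a bias-free patch embed with the updated class:

import torch
from timm.models.layers import PatchEmbed

# bias=False matches the CLIP-style bias-free patch projection.
embed = PatchEmbed(img_size=224, patch_size=32, embed_dim=768, bias=False)
x = torch.randn(1, 3, 224, 224)
tokens = embed(x)  # (1, 49, 768): a 7x7 grid of 32x32 patches, flattened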