1313 from torch .hub import _get_torch_home as get_dir
1414
1515from timm import __version__
16+
1617try :
1718 from huggingface_hub import HfApi , HfFolder , Repository , hf_hub_download , hf_hub_url
1819 hf_hub_download = partial (hf_hub_download , library_name = "timm" , library_version = __version__ )
@@ -55,7 +56,7 @@ def download_cached_file(url, check_hash=True, progress=False):
5556
5657def has_hf_hub (necessary = False ):
5758 if not _has_hf_hub and necessary :
58- # if no HF Hub module installed and it is necessary to continue, raise error
59+ # if no HF Hub module installed, and it is necessary to continue, raise error
5960 raise RuntimeError (
6061 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.' )
6162 return _has_hf_hub
@@ -78,7 +79,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
7879
7980def _download_from_hf (model_id : str , filename : str ):
8081 hf_model_id , hf_revision = hf_split (model_id )
81- return hf_hub_download (hf_model_id , filename , revision = hf_revision , cache_dir = get_cache_dir ( 'hf' ) )
82+ return hf_hub_download (hf_model_id , filename , revision = hf_revision )
8283
8384
8485def load_model_config_from_hf (model_id : str ):
@@ -91,9 +92,9 @@ def load_model_config_from_hf(model_id: str):
9192 return pretrained_cfg , model_name
9293
9394
94- def load_state_dict_from_hf (model_id : str ):
95+ def load_state_dict_from_hf (model_id : str , filename : str = 'pytorch_model.bin' ):
9596 assert has_hf_hub (True )
96- cached_file = _download_from_hf (model_id , 'pytorch_model.bin' )
97+ cached_file = _download_from_hf (model_id , filename )
9798 state_dict = torch .load (cached_file , map_location = 'cpu' )
9899 return state_dict
99100
0 commit comments