@@ -17,7 +17,6 @@

try:
    import safetensors.torch
-
    _has_safetensors = True
except ImportError:
    _has_safetensors = False
@@ -33,25 +32,16 @@
try:
    from huggingface_hub import HfApi, hf_hub_download, model_info
    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
-
    hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False

_logger = logging.getLogger(__name__)

-__all__ = [
-    'get_cache_dir',
-    'download_cached_file',
-    'has_hf_hub',
-    'hf_split',
-    'load_model_config_from_hf',
-    'load_state_dict_from_hf',
-    'save_for_hf',
-    'push_to_hf_hub',
-]
+__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
+           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
@@ -76,10 +66,10 @@ def get_cache_dir(child_dir: str = ''):


def download_cached_file(
-    url: Union[str, List[str], Tuple[str, str]],
-    check_hash: bool = True,
-    progress: bool = False,
-    cache_dir: Optional[Union[str, Path]] = None,
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        progress: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
):
    if isinstance(url, (list, tuple)):
        url, filename = url
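Note: an illustrative call against the signature restored above, not part of the commit. The (url, filename) tuple form matches the Union[str, List[str], Tuple[str, str]] annotation; the URL, filename, and cache directory are placeholders, and treating the return value as the local cached path is an assumption based on the function's name.

# Placeholder URL/filename; check_hash verifies the sha256 prefix embedded in the filename.
ckpt_path = download_cached_file(
    ('https://example.com/weights/model-0123abcd.pth', 'model-0123abcd.pth'),
    check_hash=True,
    cache_dir='/tmp/timm_cache',  # falls back to the default cache dir when omitted
)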
@@ -102,9 +92,9 @@ def download_cached_file(


def check_cached_file(
-    url: Union[str, List[str], Tuple[str, str]],
-    check_hash: bool = True,
-    cache_dir: Optional[Union[str, Path]] = None,
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
):
    if isinstance(url, (list, tuple)):
        url, filename = url
@@ -121,7 +111,7 @@ def check_cached_file(
            if hash_prefix:
                with open(cached_file, 'rb') as f:
                    hd = hashlib.sha256(f.read()).hexdigest()
-                if hd[: len(hash_prefix)] != hash_prefix:
+                if hd[:len(hash_prefix)] != hash_prefix:
                    return False
        return True
    return False
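Note: a minimal sketch of the hash check in this hunk, assuming the torch.hub-style convention where a short sha256 prefix is embedded in the filename (e.g. 'model-0676ba61.pth'); the regex and helper name below are illustrative, not from this file.

import hashlib
import re

_HASH_RE = re.compile(r'-([a-f0-9]*)\.')  # hypothetical stand-in for the HASH_REGEX used above

def _hash_ok(cached_file: str, filename: str) -> bool:
    m = _HASH_RE.search(filename)          # 'model-0676ba61.pth' -> '0676ba61'
    hash_prefix = m.group(1) if m else None
    if not hash_prefix:
        return True
    with open(cached_file, 'rb') as f:
        hd = hashlib.sha256(f.read()).hexdigest()
    return hd[:len(hash_prefix)] == hash_prefix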
@@ -131,8 +121,7 @@ def has_hf_hub(necessary: bool = False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
-            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
-        )
+            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


@@ -152,9 +141,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):


def download_from_hf(
-    model_id: str,
-    filename: str,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        filename: str,
+        cache_dir: Optional[Union[str, Path]] = None,
):
    hf_model_id, hf_revision = hf_split(model_id)
    return hf_hub_download(
@@ -166,8 +155,8 @@ def download_from_hf(


def _parse_model_cfg(
-    cfg: Dict[str, Any],
-    extra_fields: Dict[str, Any],
+        cfg: Dict[str, Any],
+        extra_fields: Dict[str, Any],
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
    """"""
    # legacy "single-dict" → split
@@ -178,7 +167,7 @@ def _parse_model_cfg(
            "num_features": pretrained_cfg.pop("num_features", None),
            "pretrained_cfg": pretrained_cfg,
        }
-        if "labels" in pretrained_cfg:  # rename --> label_names
+        if "labels" in pretrained_cfg: # rename --> label_names
            pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")

    pretrained_cfg = cfg["pretrained_cfg"]
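Note: a sketch of the two config.json shapes _parse_model_cfg appears to accept, inferred only from the lines shown in this hunk; the field values are invented.

# Legacy "single-dict" form: pretrained cfg keys at the top level, 'labels' still allowed.
legacy_cfg = {
    "architecture": "resnet50",            # hypothetical
    "num_classes": 1000,
    "labels": ["cat", "dog"],              # renamed to 'label_names' by the code above
}

# Split form: model metadata at the top, pretrained cfg nested under 'pretrained_cfg'.
split_cfg = {
    "architecture": "resnet50",
    "num_features": 2048,
    "pretrained_cfg": {"num_classes": 1000, "label_names": ["cat", "dog"]},
}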
@@ -198,8 +187,8 @@ def _parse_model_cfg(


def load_model_config_from_hf(
-    model_id: str,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        cache_dir: Optional[Union[str, Path]] = None,
):
    """Original HF-Hub loader (unchanged download, shared parsing)."""
    assert has_hf_hub(True)
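Note: illustrative usage, not from the commit. The repo id is a placeholder, the optional '@revision' suffix follows the hf_split convention used throughout this file, and the three-element unpacking is an assumption based on _parse_model_cfg's Tuple[Dict, str, Dict] return annotation.

# Placeholder repo id; '@main' pins an optional revision.
pretrained_cfg, model_name, model_args = load_model_config_from_hf('your-org/your-model@main')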
@@ -209,7 +198,7 @@ def load_model_config_from_hf(


def load_model_config_from_path(
-    model_path: Union[str, Path],
+        model_path: Union[str, Path],
):
    """Load from ``<model_path>/config.json`` on the local filesystem."""
    model_path = Path(model_path)
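Note: a hypothetical local layout for the path-based loader above; the weight filename simply follows the preferred-file names referenced later in this diff and is not mandated by this hunk.

# ./my_model/
#     config.json        <- read by load_model_config_from_path('./my_model')
#     model.safetensors  <- later resolved by load_state_dict_from_path('./my_model')
local_cfg = load_model_config_from_path('./my_model')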
@@ -222,10 +211,10 @@ def load_model_config_from_path(


def load_state_dict_from_hf(
-    model_id: str,
-    filename: str = HF_WEIGHTS_NAME,
-    weights_only: bool = False,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        filename: str = HF_WEIGHTS_NAME,
+        weights_only: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
):
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)
@@ -242,8 +231,7 @@ def load_state_dict_from_hf(
                )
                _logger.info(
                    f"[{model_id}] Safe alternative available for '{filename}' "
-                    f"(as '{safe_filename}'). Loading weights using safetensors."
-                )
+                    f"(as '{safe_filename}'). Loading weights using safetensors.")
                return safetensors.torch.load_file(cached_safe_file, device="cpu")
            except EntryNotFoundError:
                pass
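Note: an illustrative call, not part of the commit. Per the branch above, when safetensors is installed and the default HF_WEIGHTS_NAME is requested, a '.safetensors' alternative is tried first and loaded to CPU; the repo id is a placeholder.

# Placeholder repo id; returns a CPU state_dict, taken from the safetensors file when one exists.
state_dict = load_state_dict_from_hf('your-org/your-model', weights_only=True)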
@@ -275,10 +263,9 @@ def load_state_dict_from_hf(
)
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')

-
def load_state_dict_from_path(
-    path: str,
-    weights_only: bool = False,
+        path: str,
+        weights_only: bool = False,
):
    found_file = None
    for fname in _PREFERRED_FILES:
@@ -293,7 +280,10 @@ def load_state_dict_from_path(
            files = sorted(path.glob(f"*{ext}"))
            if files:
                if len(files) > 1:
-                    logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
+                    logging.warning(
+                        f"Multiple {ext} checkpoints in {path}: {names}. "
+                        f"Using '{files[0].name}'."
+                    )
                found_file = files[0]

    if not found_file:
@@ -307,10 +297,10 @@ def load_state_dict_from_path(


def load_custom_from_hf(
-    model_id: str,
-    filename: str,
-    model: torch.nn.Module,
-    cache_dir: Optional[Union[str, Path]] = None,
+        model_id: str,
+        filename: str,
+        model: torch.nn.Module,
+        cache_dir: Optional[Union[str, Path]] = None,
):
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)
@@ -324,7 +314,10 @@ def load_custom_from_hf(


def save_config_for_hf(
-    model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
+        model: torch.nn.Module,
+        config_path: str,
+        model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None
):
    model_config = model_config or {}
    hf_config = {}
@@ -343,8 +336,7 @@ def save_config_for_hf(
    if 'labels' in model_config:
        _logger.warning(
            "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
-            " Renaming provided 'labels' field to 'label_names'."
-        )
+            " Renaming provided 'labels' field to 'label_names'.")
        model_config.setdefault('label_names', model_config.pop('labels'))

    label_names = model_config.pop('label_names', None)
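Note: a sketch of the model_config fields handled above; the deprecated 'labels' key is folded into 'label_names', and 'label_descriptions' is the companion field named in the warning. Values are invented.

# Deprecated spelling, still accepted: rewritten to 'label_names' by the code above.
model_config = {'labels': ['cat', 'dog']}

# Preferred spelling.
model_config = {
    'label_names': ['cat', 'dog'],
    'label_descriptions': {'cat': 'domestic cat', 'dog': 'domestic dog'},
}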
@@ -371,11 +363,11 @@ def save_config_for_hf(


def save_for_hf(
-    model: torch.nn.Module,
-    save_directory: str,
-    model_config: Optional[dict] = None,
-    model_args: Optional[dict] = None,
-    safe_serialization: Union[bool, Literal["both"]] = False,
+        model: torch.nn.Module,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        model_args: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False,
):
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
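Note: an illustrative call for the restored signature above; the output directory and config values are placeholders, and writing both a .bin and a .safetensors checkpoint for safe_serialization='both' is an assumption drawn from the Literal["both"] option.

save_for_hf(
    model,                                        # any timm model instance
    'out/my-model',                               # placeholder directory; gets config.json + weights
    model_config={'label_names': ['cat', 'dog']},
    safe_serialization='both',                    # presumably writes .bin and .safetensors side by side
)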
@@ -399,18 +391,18 @@ def save_for_hf(


def push_to_hf_hub(
-    model: torch.nn.Module,
-    repo_id: str,
-    commit_message: str = 'Add model',
-    token: Optional[str] = None,
-    revision: Optional[str] = None,
-    private: bool = False,
-    create_pr: bool = False,
-    model_config: Optional[dict] = None,
-    model_card: Optional[dict] = None,
-    model_args: Optional[dict] = None,
-    task_name: str = 'image-classification',
-    safe_serialization: Union[bool, Literal["both"]] = 'both',
+        model: torch.nn.Module,
+        repo_id: str,
+        commit_message: str = 'Add model',
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        private: bool = False,
+        create_pr: bool = False,
+        model_config: Optional[dict] = None,
+        model_card: Optional[dict] = None,
+        model_args: Optional[dict] = None,
+        task_name: str = 'image-classification',
+        safe_serialization: Union[bool, Literal["both"]] = 'both',
):
    """
    Arguments:
@@ -460,9 +452,9 @@ def push_to_hf_hub(


def generate_readme(
-    model_card: dict,
-    model_name: str,
-    task_name: str = 'image-classification',
+        model_card: dict,
+        model_name: str,
+        task_name: str = 'image-classification',
):
    tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
    readme_text = "---\n"
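Note: a hedged end-to-end sketch using only parameters visible in the push_to_hf_hub signature above; the repo id, labels, and card fields are placeholders, and timm.create_model is assumed to be available in the caller's environment.

import timm

model = timm.create_model('resnet50', pretrained=True, num_classes=2)
push_to_hf_hub(
    model,
    'your-org/resnet50-cats-vs-dogs',            # placeholder repo id
    commit_message='Add fine-tuned weights',
    model_config={'label_names': ['cat', 'dog']},
    model_card={'tags': ['image-classification', 'timm']},
    safe_serialization='both',
)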