Skip to content

Commit f619bef

Browse files
committed
Revert "Running formatting with command from CONTRIBUTING.md"
This reverts commit ed00d06. Reducing diff to keep pull request only for functional change.
1 parent ffddf59 commit f619bef

File tree

1 file changed

+60
-68
lines changed

1 file changed

+60
-68
lines changed

timm/models/_hub.py

Lines changed: 60 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
try:
1919
import safetensors.torch
20-
2120
_has_safetensors = True
2221
except ImportError:
2322
_has_safetensors = False
@@ -33,7 +32,6 @@
3332
try:
3433
from huggingface_hub import HfApi, hf_hub_download, model_info
3534
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
36-
3735
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3836
_has_hf_hub = True
3937
except ImportError:
@@ -42,16 +40,8 @@
4240

4341
_logger = logging.getLogger(__name__)
4442

45-
__all__ = [
46-
'get_cache_dir',
47-
'download_cached_file',
48-
'has_hf_hub',
49-
'hf_split',
50-
'load_model_config_from_hf',
51-
'load_state_dict_from_hf',
52-
'save_for_hf',
53-
'push_to_hf_hub',
54-
]
43+
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
44+
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
5545

5646
# Default name for a weights file hosted on the Huggingface Hub.
5747
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
@@ -76,10 +66,10 @@ def get_cache_dir(child_dir: str = ''):
7666

7767

7868
def download_cached_file(
79-
url: Union[str, List[str], Tuple[str, str]],
80-
check_hash: bool = True,
81-
progress: bool = False,
82-
cache_dir: Optional[Union[str, Path]] = None,
69+
url: Union[str, List[str], Tuple[str, str]],
70+
check_hash: bool = True,
71+
progress: bool = False,
72+
cache_dir: Optional[Union[str, Path]] = None,
8373
):
8474
if isinstance(url, (list, tuple)):
8575
url, filename = url
@@ -102,9 +92,9 @@ def download_cached_file(
10292

10393

10494
def check_cached_file(
105-
url: Union[str, List[str], Tuple[str, str]],
106-
check_hash: bool = True,
107-
cache_dir: Optional[Union[str, Path]] = None,
95+
url: Union[str, List[str], Tuple[str, str]],
96+
check_hash: bool = True,
97+
cache_dir: Optional[Union[str, Path]] = None,
10898
):
10999
if isinstance(url, (list, tuple)):
110100
url, filename = url
@@ -121,7 +111,7 @@ def check_cached_file(
121111
if hash_prefix:
122112
with open(cached_file, 'rb') as f:
123113
hd = hashlib.sha256(f.read()).hexdigest()
124-
if hd[: len(hash_prefix)] != hash_prefix:
114+
if hd[:len(hash_prefix)] != hash_prefix:
125115
return False
126116
return True
127117
return False
@@ -131,8 +121,7 @@ def has_hf_hub(necessary: bool = False):
131121
if not _has_hf_hub and necessary:
132122
# if no HF Hub module installed, and it is necessary to continue, raise error
133123
raise RuntimeError(
134-
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.'
135-
)
124+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
136125
return _has_hf_hub
137126

138127

@@ -152,9 +141,9 @@ def load_cfg_from_json(json_file: Union[str, Path]):
152141

153142

154143
def download_from_hf(
155-
model_id: str,
156-
filename: str,
157-
cache_dir: Optional[Union[str, Path]] = None,
144+
model_id: str,
145+
filename: str,
146+
cache_dir: Optional[Union[str, Path]] = None,
158147
):
159148
hf_model_id, hf_revision = hf_split(model_id)
160149
return hf_hub_download(
@@ -166,8 +155,8 @@ def download_from_hf(
166155

167156

168157
def _parse_model_cfg(
169-
cfg: Dict[str, Any],
170-
extra_fields: Dict[str, Any],
158+
cfg: Dict[str, Any],
159+
extra_fields: Dict[str, Any],
171160
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
172161
""""""
173162
# legacy "single‑dict" → split
@@ -178,7 +167,7 @@ def _parse_model_cfg(
178167
"num_features": pretrained_cfg.pop("num_features", None),
179168
"pretrained_cfg": pretrained_cfg,
180169
}
181-
if "labels" in pretrained_cfg: # rename ‑‑> label_names
170+
if "labels" in pretrained_cfg: # rename ‑‑> label_names
182171
pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
183172

184173
pretrained_cfg = cfg["pretrained_cfg"]
@@ -198,8 +187,8 @@ def _parse_model_cfg(
198187

199188

200189
def load_model_config_from_hf(
201-
model_id: str,
202-
cache_dir: Optional[Union[str, Path]] = None,
190+
model_id: str,
191+
cache_dir: Optional[Union[str, Path]] = None,
203192
):
204193
"""Original HF‑Hub loader (unchanged download, shared parsing)."""
205194
assert has_hf_hub(True)
@@ -209,7 +198,7 @@ def load_model_config_from_hf(
209198

210199

211200
def load_model_config_from_path(
212-
model_path: Union[str, Path],
201+
model_path: Union[str, Path],
213202
):
214203
"""Load from ``<model_path>/config.json`` on the local filesystem."""
215204
model_path = Path(model_path)
@@ -222,10 +211,10 @@ def load_model_config_from_path(
222211

223212

224213
def load_state_dict_from_hf(
225-
model_id: str,
226-
filename: str = HF_WEIGHTS_NAME,
227-
weights_only: bool = False,
228-
cache_dir: Optional[Union[str, Path]] = None,
214+
model_id: str,
215+
filename: str = HF_WEIGHTS_NAME,
216+
weights_only: bool = False,
217+
cache_dir: Optional[Union[str, Path]] = None,
229218
):
230219
assert has_hf_hub(True)
231220
hf_model_id, hf_revision = hf_split(model_id)
@@ -242,8 +231,7 @@ def load_state_dict_from_hf(
242231
)
243232
_logger.info(
244233
f"[{model_id}] Safe alternative available for '{filename}' "
245-
f"(as '{safe_filename}'). Loading weights using safetensors."
246-
)
234+
f"(as '{safe_filename}'). Loading weights using safetensors.")
247235
return safetensors.torch.load_file(cached_safe_file, device="cpu")
248236
except EntryNotFoundError:
249237
pass
@@ -275,10 +263,9 @@ def load_state_dict_from_hf(
275263
)
276264
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
277265

278-
279266
def load_state_dict_from_path(
280-
path: str,
281-
weights_only: bool = False,
267+
path: str,
268+
weights_only: bool = False,
282269
):
283270
found_file = None
284271
for fname in _PREFERRED_FILES:
@@ -293,7 +280,10 @@ def load_state_dict_from_path(
293280
files = sorted(path.glob(f"*{ext}"))
294281
if files:
295282
if len(files) > 1:
296-
logging.warning(f"Multiple {ext} checkpoints in {path}: {names}. " f"Using '{files[0].name}'.")
283+
logging.warning(
284+
f"Multiple {ext} checkpoints in {path}: {names}. "
285+
f"Using '{files[0].name}'."
286+
)
297287
found_file = files[0]
298288

299289
if not found_file:
@@ -307,10 +297,10 @@ def load_state_dict_from_path(
307297

308298

309299
def load_custom_from_hf(
310-
model_id: str,
311-
filename: str,
312-
model: torch.nn.Module,
313-
cache_dir: Optional[Union[str, Path]] = None,
300+
model_id: str,
301+
filename: str,
302+
model: torch.nn.Module,
303+
cache_dir: Optional[Union[str, Path]] = None,
314304
):
315305
assert has_hf_hub(True)
316306
hf_model_id, hf_revision = hf_split(model_id)
@@ -324,7 +314,10 @@ def load_custom_from_hf(
324314

325315

326316
def save_config_for_hf(
327-
model: torch.nn.Module, config_path: str, model_config: Optional[dict] = None, model_args: Optional[dict] = None
317+
model: torch.nn.Module,
318+
config_path: str,
319+
model_config: Optional[dict] = None,
320+
model_args: Optional[dict] = None
328321
):
329322
model_config = model_config or {}
330323
hf_config = {}
@@ -343,8 +336,7 @@ def save_config_for_hf(
343336
if 'labels' in model_config:
344337
_logger.warning(
345338
"'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
346-
" Renaming provided 'labels' field to 'label_names'."
347-
)
339+
" Renaming provided 'labels' field to 'label_names'.")
348340
model_config.setdefault('label_names', model_config.pop('labels'))
349341

350342
label_names = model_config.pop('label_names', None)
@@ -371,11 +363,11 @@ def save_config_for_hf(
371363

372364

373365
def save_for_hf(
374-
model: torch.nn.Module,
375-
save_directory: str,
376-
model_config: Optional[dict] = None,
377-
model_args: Optional[dict] = None,
378-
safe_serialization: Union[bool, Literal["both"]] = False,
366+
model: torch.nn.Module,
367+
save_directory: str,
368+
model_config: Optional[dict] = None,
369+
model_args: Optional[dict] = None,
370+
safe_serialization: Union[bool, Literal["both"]] = False,
379371
):
380372
assert has_hf_hub(True)
381373
save_directory = Path(save_directory)
@@ -399,18 +391,18 @@ def save_for_hf(
399391

400392

401393
def push_to_hf_hub(
402-
model: torch.nn.Module,
403-
repo_id: str,
404-
commit_message: str = 'Add model',
405-
token: Optional[str] = None,
406-
revision: Optional[str] = None,
407-
private: bool = False,
408-
create_pr: bool = False,
409-
model_config: Optional[dict] = None,
410-
model_card: Optional[dict] = None,
411-
model_args: Optional[dict] = None,
412-
task_name: str = 'image-classification',
413-
safe_serialization: Union[bool, Literal["both"]] = 'both',
394+
model: torch.nn.Module,
395+
repo_id: str,
396+
commit_message: str = 'Add model',
397+
token: Optional[str] = None,
398+
revision: Optional[str] = None,
399+
private: bool = False,
400+
create_pr: bool = False,
401+
model_config: Optional[dict] = None,
402+
model_card: Optional[dict] = None,
403+
model_args: Optional[dict] = None,
404+
task_name: str = 'image-classification',
405+
safe_serialization: Union[bool, Literal["both"]] = 'both',
414406
):
415407
"""
416408
Arguments:
@@ -460,9 +452,9 @@ def push_to_hf_hub(
460452

461453

462454
def generate_readme(
463-
model_card: dict,
464-
model_name: str,
465-
task_name: str = 'image-classification',
455+
model_card: dict,
456+
model_name: str,
457+
task_name: str = 'image-classification',
466458
):
467459
tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers']
468460
readme_text = "---\n"

0 commit comments

Comments
 (0)