Skip to content

Commit 25c4790

Browse files
committed
Adding new approach
1 parent 288e5b5 commit 25c4790

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

timm/models/_hub.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -535,17 +535,17 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
535535
yield filename[:-4] + ".safetensors"
536536

537537

538-
def _get_license_from_hf_hub(model_id: str | None, hf_hub_id: str | None) -> str | None:
538+
def _get_license_from_hf_hub(hf_hub_id: str | None) -> str | None:
539539
"""Retrieve license information for a model from Hugging Face Hub.
540540
541541
Fetches the license field from the model card metadata on Hugging Face Hub
542-
for the specified model. Returns None if the model is not found, if
543-
huggingface_hub is not installed, or if the model is marked as "untrained".
542+
for the specified model. This function is called lazily when the license
543+
attribute is accessed on PretrainedCfg objects that don't have an explicit
544+
license set.
544545
545546
Args:
546-
model_id: The model identifier/name. In the case of None we assume an untrained model.
547-
hf_hub_id: The Hugging Face Hub organization/user ID. If it is None,
548-
we will return None as we cannot infer the license terms.
547+
hf_hub_id: The Hugging Face Hub model ID (e.g., 'organization/model').
548+
If None or empty, returns None as license cannot be determined.
549549
550550
Returns:
551551
The license string in lowercase if found, None otherwise.
@@ -559,17 +559,17 @@ def _get_license_from_hf_hub(model_id: str | None, hf_hub_id: str | None) -> str
559559
_logger.warning(msg=msg)
560560
return None
561561

562-
if not (model_id and hf_hub_id):
562+
if hf_hub_id is None or hf_hub_id == "timm/":
563563
return None
564564

565-
repo_id: str = hf_hub_id + model_id
566-
567565
try:
568-
info = model_info(repo_id=repo_id)
566+
info = model_info(repo_id=hf_hub_id)
569567

570568
except RepositoryNotFoundError:
571569
# TODO: any wish what happens here? @rwightman
572-
print(repo_id)
570+
return None
571+
572+
except Exception as _:
573573
return None
574574

575575
license = info.card_data.get("license").lower() if info.card_data else None

timm/models/_pretrained.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ class PretrainedCfg:
5858
def has_weights(self):
5959
return self.url or self.file or self.hf_hub_id
6060

61+
def __getattribute__(self, name):
62+
if name == 'license': # Intercept license access to set it in case it was not set anywhere else.
63+
license_value = super().__getattribute__('license')
64+
65+
if license_value is None:
66+
from ._hub import _get_license_from_hf_hub
67+
license_value = _get_license_from_hf_hub(hf_hub_id=self.hf_hub_id)
68+
69+
self.license = license_value
70+
71+
return license_value
72+
73+
return super().__getattribute__(name)
74+
6175
def to_dict(self, remove_source=False, remove_null=True):
6276
return filter_pretrained_cfg(
6377
asdict(self),

0 commit comments

Comments
 (0)