diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py
index 2ae03bc8dc..387155dfb9 100644
--- a/monai/losses/perceptual.py
+++ b/monai/losses/perceptual.py
@@ -16,11 +16,14 @@
 import torch
 import torch.nn as nn
 
+from monai.apps.utils import check_hash
+from monai.networks.nets.resnet import ResNetFeatures
 from monai.utils import optional_import
 from monai.utils.enums import StrEnum
 
 LPIPS, _ = optional_import("lpips", name="LPIPS")
 torchvision, _ = optional_import("torchvision")
+hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download")
 
 
 class PercetualNetworkType(StrEnum):
@@ -86,7 +89,7 @@ def __init__(
 
         if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
             raise ValueError(
-                "MedicalNet networks are only compatible with ``spatial_dims=3``."
+                "MedicalNet networks are only compatible with ``spatial_dims=3``. "
                 "Argument is_fake_3d must be set to False."
             )
 
@@ -193,13 +196,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 class MedicalNetPerceptualSimilarity(nn.Module):
     """
     Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
-    Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
-    "Warvito/MedicalNet-models".
+    Learning for 3D Medical Image Analysis". This class downloads the pretrained weights from the Hugging Face
+    repository "MONAI/checkpoints".
 
     Args:
         net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
             Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
-        verbose: if false, mute messages from torch Hub load function.
+        verbose: unused; kept for backwards compatibility.
         channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
             Defaults to ``False``.
     """
@@ -208,8 +211,8 @@ def __init__(
         self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
     ) -> None:
         super().__init__()
-        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
-        self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
+        # Load the pretrained MedicalNet backbone from the Hugging Face Hub
+        self.model = self._load_medicalnet_from_hf(net)
         self.eval()
         self.channel_wise = channel_wise
 
@@ -217,6 +220,48 @@ def __init__(
         for param in self.parameters():
             param.requires_grad = False
 
+    def _load_medicalnet_from_hf(self, net: str) -> nn.Module:
+        """Load a MedicalNet model from the Hugging Face Hub."""
+        # Map network names to (backbone name, checkpoint file name, SHA-256 hash)
+        model_mapping = {
+            "medicalnet_resnet10_23datasets": (
+                "resnet10",
+                "medicalnet_resnet_10_23dataset.pth",
+                "afa8055f3e47f4a18239495d92a7abc587902c69c31c743de2b2784653b72605",
+            ),
+            "medicalnet_resnet50_23datasets": (
+                "resnet50",
+                "medicalnet_resnet50_23datasets.pth",
+                "ff48a62219073fb977fd3f4ddfb8dc1367f0ec156c8d6f6c37e205bd683a246e",
+            ),
+        }
+
+        if net not in model_mapping:
+            raise ValueError(f"Unsupported network: {net}. Choose from {list(model_mapping.keys())}")
+
+        model_name, filename, hash_val = model_mapping[net]
+
+        # Download weights from Hugging Face
+        pretrained_path = hf_hub_download(repo_id="MONAI/checkpoints", filename=filename)
+
+        if not check_hash(pretrained_path, hash_val, hash_type="sha256"):
+            raise RuntimeError(f"Hash mismatch for file {filename}.")
+
+        # Build the backbone with MONAI's ResNetFeatures, which returns intermediate feature maps rather than logits
+        model = ResNetFeatures(model_name=model_name, pretrained=False, spatial_dims=3, in_channels=1)
+
+        # Load the pretrained weights
+        checkpoint = torch.load(pretrained_path, map_location="cpu", weights_only=True)
+        state_dict = checkpoint.get("state_dict", checkpoint)
+
+        # Strip the 'module.' prefix left over from DataParallel training, if present
+        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+        model.load_state_dict(state_dict, strict=False)
+
+        # Wrap so that the forward pass returns only the deepest feature map, pooled
+        return MedicalNetFeaturesWrapper(model)
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the
@@ -267,6 +312,28 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return results
 
 
+class MedicalNetFeaturesWrapper(nn.Module):
+    """Wrapper to extract and pool the last feature map from ResNetFeatures."""
+
+    def __init__(self, resnet_features_model: nn.Module) -> None:
+        super().__init__()
+        self.model = resnet_features_model
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # ResNetFeatures returns a list of feature maps at increasing depth;
+        # keep only the deepest one and pool it to a single feature vector.
+        features = self.model(x)
+
+        # Get the deepest feature map
+        last_features = features[-1]
+
+        # Reuse the avgpool layer that is already part of the wrapped ResNet
+        # to collapse the spatial dimensions of the feature map.
+        pooled = self.model.avgpool(last_features)
+
+        return pooled
+
+
 def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
     return x.mean([2, 3, 4], keepdim=keepdim)
 
@@ -287,22 +354,56 @@ class RadImageNetPerceptualSimilarity(nn.Module):
     """
     Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
     al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
-    uses torch Hub to download the networks from "Warvito/radimagenet-models".
+    downloads the pretrained weights from the Hugging Face repository "MONAI/checkpoints".
 
     Args:
         net: {``"radimagenet_resnet50"``}
             Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
-        verbose: if false, mute messages from torch Hub load function.
+        verbose: unused; kept for backwards compatibility.
     """
 
     def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
         super().__init__()
-        self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
+        self.model = self._load_radimagenet_from_hf(net)
         self.eval()
 
         for param in self.parameters():
             param.requires_grad = False
 
+    def _load_radimagenet_from_hf(self, net: str) -> nn.Module:
+        """Load a RadImageNet model from the Hugging Face Hub."""
+        if net != "radimagenet_resnet50":
+            raise ValueError(f"Unsupported network: {net}. Only 'radimagenet_resnet50' is supported.")
+
+        filename = "RadImageNet-ResNet50_notop.pth"
+        hash_val = "2457479b254569e5a81ba48fee6c5b2b84b7a729e507aaa2466101aedb8e5c37"
+
+        # Download weights from Hugging Face
+        pretrained_path = hf_hub_download(repo_id="MONAI/checkpoints", filename=filename)
+
+        if not check_hash(pretrained_path, hash_val, hash_type="sha256"):
+            raise RuntimeError(f"Hash mismatch for file {filename}.")
+
+        # Create a ResNet50 backbone with torchvision
+        model = torchvision.models.resnet50(weights=None)
+
+        # Remove the final classification layer (we only need features)
+        model = nn.Sequential(*list(model.children())[:-1])
+
+        # Load the pretrained weights
+        state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
+
+        # Some checkpoints nest the weights under a "state_dict" key; unwrap it first,
+        # otherwise loading with strict=False would silently match nothing.
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+
+        # strict=False tolerates key mismatches between the checkpoint and the
+        # torchvision ResNet50 backbone (e.g. the absent classification head).
+        model.load_state_dict(state_dict, strict=False)
+
+        return model
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
diff --git a/tests/losses/test_huggingface_loading_perceptual_loss.py b/tests/losses/test_huggingface_loading_perceptual_loss.py
new file mode 100644
index 0000000000..e42fc766f9
--- /dev/null
+++ b/tests/losses/test_huggingface_loading_perceptual_loss.py
@@ -0,0 +1,74 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest.mock import patch
+
+import torch
+
+from monai.losses import PerceptualLoss
+from monai.utils import optional_import
+from tests.test_utils import skip_if_downloading_fails, skip_if_quick
+
+_, has_torchvision = optional_import("torchvision")
+
+
+@unittest.skipUnless(has_torchvision, "Requires torchvision")
+@skip_if_quick
+class TestHuggingFaceLoadingPerceptualLoss(unittest.TestCase):
+    def test_medicalnet_resnet10_loading(self):
+        """Test MedicalNet ResNet10 loading from Hugging Face."""
+        with skip_if_downloading_fails():
+            loss = PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=False)
+
+        input_tensor = torch.randn(1, 1, 32, 32, 32)
+        target_tensor = torch.randn(1, 1, 32, 32, 32)
+        result = loss(input_tensor, target_tensor)
+
+        self.assertEqual(result.shape, torch.Size([]))
+        self.assertIsInstance(result.item(), float)
+
+    def test_medicalnet_resnet50_loading(self):
+        """Test MedicalNet ResNet50 loading from Hugging Face."""
+        with skip_if_downloading_fails():
+            loss = PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet50_23datasets", is_fake_3d=False)
+
+        input_tensor = torch.randn(1, 1, 32, 32, 32)
+        target_tensor = torch.randn(1, 1, 32, 32, 32)
+        result = loss(input_tensor, target_tensor)
+
+        self.assertEqual(result.shape, torch.Size([]))
+        self.assertIsInstance(result.item(), float)
+
+    def test_radimagenet_loading(self):
+        """Test RadImageNet ResNet50 loading from Hugging Face."""
+        with skip_if_downloading_fails():
+            loss = PerceptualLoss(spatial_dims=2, network_type="radimagenet_resnet50")
+
+        input_tensor = torch.randn(1, 1, 64, 64)
+        target_tensor = torch.randn(1, 1, 64, 64)
+        result = loss(input_tensor, target_tensor)
+
+        self.assertEqual(result.shape, torch.Size([]))
+        self.assertIsInstance(result.item(), float)
+
+    def test_checksum_failure(self):
+        """Test that a RuntimeError is raised when checksum verification fails."""
+        with patch("monai.losses.perceptual.hf_hub_download", return_value="/tmp/dummy_path"):
+            with patch("monai.losses.perceptual.check_hash", return_value=False):
+                with self.assertRaisesRegex(RuntimeError, "Hash mismatch for file"):
+                    PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=False)
+
+
+if __name__ == "__main__":
+    unittest.main()
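
For reference, a minimal usage sketch of the loading path exercised by the tests above (an illustration only: it assumes the optional huggingface_hub dependency is installed and that the "MONAI/checkpoints" weights can be downloaded; tensor shapes are arbitrary):

import torch

from monai.losses import PerceptualLoss

# 3D MedicalNet-based perceptual loss; the pretrained weights are fetched from the
# "MONAI/checkpoints" Hugging Face repository when the loss is first instantiated.
loss_fn = PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=False)

prediction = torch.randn(1, 1, 32, 32, 32)
target = torch.randn(1, 1, 32, 32, 32)
print(loss_fn(prediction, target))  # prints a scalar tensor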