Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 110 additions & 9 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
import torch
import torch.nn as nn

from monai.apps.utils import check_hash
from monai.networks.nets.resnet import ResNetFeatures
from monai.utils import optional_import
from monai.utils.enums import StrEnum

LPIPS, _ = optional_import("lpips", name="LPIPS")
torchvision, _ = optional_import("torchvision")
hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download")


class PercetualNetworkType(StrEnum):
Expand Down Expand Up @@ -86,7 +89,7 @@ def __init__(

if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
raise ValueError(
"MedicalNet networks are only compatible with ``spatial_dims=3``."
"MedicalNet networks are only compatible with ``spatial_dims=3``. "
"Argument is_fake_3d must be set to False."
)

Expand Down Expand Up @@ -193,13 +196,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
class MedicalNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
"Warvito/MedicalNet-models".
Learning for 3D Medical Image Analysis". This class downloads the pretrained weights from the Hugging Face
repository "MONAI/checkpoints".

Args:
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
verbose: if false, mute messages from model loading (currently unused).
channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
Defaults to ``False``.
"""
Expand All @@ -208,15 +211,57 @@ def __init__(
self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
# Load model from Hugging Face
self.model = self._load_medicalnet_from_hf(net)
self.eval()

self.channel_wise = channel_wise

for param in self.parameters():
param.requires_grad = False

def _load_medicalnet_from_hf(self, net: str) -> nn.Module:
"""Load MedicalNet model from Hugging Face hub."""
# Map network names to model names and file names
model_mapping = {
"medicalnet_resnet10_23datasets": (
"resnet10",
"medicalnet_resnet_10_23dataset.pth",
"afa8055f3e47f4a18239495d92a7abc587902c69c31c743de2b2784653b72605",
),
"medicalnet_resnet50_23datasets": (
"resnet50",
"medicalnet_resnet50_23datasets.pth",
"ff48a62219073fb977fd3f4ddfb8dc1367f0ec156c8d6f6c37e205bd683a246e",
),
}

if net not in model_mapping:
raise ValueError(f"Unsupported network: {net}. Choose from {list(model_mapping.keys())}")

model_name, filename, hash_val = model_mapping[net]

# Download weights from Hugging Face
pretrained_path = hf_hub_download(repo_id="MONAI/checkpoints", filename=filename)

if not check_hash(pretrained_path, hash_val, hash_type="sha256"):
raise RuntimeError(f"Hash mismatch for file {filename}.")

# Create model using MONAI's ResNetFeatures (which returns feature maps, not final classification)
model = ResNetFeatures(model_name=model_name, pretrained=False, spatial_dims=3, in_channels=1)

# Load the pretrained weights
checkpoint = torch.load(pretrained_path, map_location="cpu", weights_only=True)
state_dict = checkpoint.get("state_dict", checkpoint)

# Remove 'module.' prefix if present
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

model.load_state_dict(state_dict, strict=False)

# Wrap to return only the last feature map (pooled)
return MedicalNetFeaturesWrapper(model)

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the
Expand Down Expand Up @@ -267,6 +312,28 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return results


class MedicalNetFeaturesWrapper(nn.Module):
    """
    Adapter exposing only the pooled, deepest feature map of a feature-extraction network.

    The wrapped model is called on the input, the last element of its output is selected, and
    that element is passed through the wrapped model's own ``avgpool`` layer.
    """

    def __init__(self, resnet_features_model):
        super().__init__()
        self.model = resnet_features_model

    def forward(self, x):
        # The wrapped network yields a sequence of feature maps; the final entry is the deepest one.
        deepest = self.model(x)[-1]
        # Reuse the pooling layer that ships with the wrapped network.
        return self.model.avgpool(deepest)


def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Average ``x`` over dimensions 2, 3 and 4, optionally keeping them as size-1 dimensions."""
    reduced_dims = [2, 3, 4]
    return torch.mean(x, dim=reduced_dims, keepdim=keepdim)

Expand All @@ -287,22 +354,56 @@ class RadImageNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
uses torch Hub to download the networks from "Warvito/radimagenet-models".
downloads the pretrained weights from the Hugging Face repository "MONAI/checkpoints".

Args:
net: {``"radimagenet_resnet50"``}
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
verbose: if false, mute messages from torch Hub load function.
verbose: if false, mute messages from model loading (currently unused).
"""

def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
    """
    Load the requested network, switch it to evaluation mode and freeze its parameters.

    Args:
        net: name of the network to load; forwarded to ``_load_radimagenet_from_hf``.
        verbose: accepted for backward compatibility; not used by the Hugging Face loading path.
    """
    super().__init__()
    # Single source of weights: the Hugging Face loader. A preceding torch.hub load would only
    # trigger an extra download whose result is immediately overwritten.
    self.model = self._load_radimagenet_from_hf(net)
    self.eval()

    # The network is a fixed feature extractor; none of its parameters should receive gradients.
    for param in self.parameters():
        param.requires_grad = False

def _load_radimagenet_from_hf(self, net: str) -> nn.Module:
"""Load RadImageNet model from Hugging Face hub."""
if net != "radimagenet_resnet50":
raise ValueError(f"Unsupported network: {net}. Only 'radimagenet_resnet50' is supported.")

filename = "RadImageNet-ResNet50_notop.pth"
hash_val = "2457479b254569e5a81ba48fee6c5b2b84b7a729e507aaa2466101aedb8e5c37"
# Download weights from Hugging Face
pretrained_path = hf_hub_download(repo_id="MONAI/checkpoints", filename=filename)

if not check_hash(pretrained_path, hash_val, hash_type="sha256"):
raise RuntimeError(f"Hash mismatch for file {filename}.")

# Create ResNet50 model using torchvision
model = torchvision.models.resnet50(weights=None)

# Remove the final classification layer (we only need features)
model = nn.Sequential(*list(model.children())[:-1])

# Load the pretrained weights
state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)

# The state dict might have a different structure, adjust as needed
# Try to load with strict=False to handle any mismatches
try:
model.load_state_dict(state_dict, strict=False)
except Exception:
# If the state dict is wrapped, unwrap it
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=False)

return model

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
Expand Down
74 changes: 74 additions & 0 deletions tests/losses/test_huggingface_loading_perceptual_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest.mock import patch

import torch

from monai.losses import PerceptualLoss
from monai.utils import optional_import
from tests.test_utils import skip_if_downloading_fails, skip_if_quick

_, has_torchvision = optional_import("torchvision")


@unittest.skipUnless(has_torchvision, "Requires torchvision")
@skip_if_quick
class TestHuggingFaceLoadingPerceptualLoss(unittest.TestCase):
    """Checks that the pretrained perceptual networks can be fetched from Hugging Face and evaluated."""

    def _build_and_check(self, shape, **loss_kwargs):
        """Create the loss (skipping if the download fails) and verify it yields a scalar float."""
        with skip_if_downloading_fails():
            criterion = PerceptualLoss(**loss_kwargs)

        prediction = torch.randn(*shape)
        reference = torch.randn(*shape)
        value = criterion(prediction, reference)

        self.assertEqual(value.shape, torch.Size([]))
        self.assertIsInstance(value.item(), float)

    def test_medicalnet_resnet10_loading(self):
        """MedicalNet ResNet10 weights load from Hugging Face and produce a scalar loss."""
        self._build_and_check(
            (1, 1, 32, 32, 32), spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=False
        )

    def test_medicalnet_resnet50_loading(self):
        """MedicalNet ResNet50 weights load from Hugging Face and produce a scalar loss."""
        self._build_and_check(
            (1, 1, 32, 32, 32), spatial_dims=3, network_type="medicalnet_resnet50_23datasets", is_fake_3d=False
        )

    def test_radimagenet_loading(self):
        """RadImageNet ResNet50 weights load from Hugging Face and produce a scalar loss."""
        self._build_and_check((1, 1, 64, 64), spatial_dims=2, network_type="radimagenet_resnet50")

    def test_checksum_failure(self):
        """A failed checksum verification surfaces as a RuntimeError."""
        download_patch = patch("monai.losses.perceptual.hf_hub_download", return_value="/tmp/dummy_path")
        hash_patch = patch("monai.losses.perceptual.check_hash", return_value=False)
        with download_patch, hash_patch, self.assertRaisesRegex(RuntimeError, "Hash mismatch for file"):
            PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=False)


if __name__ == "__main__":
unittest.main()