Skip to content

Commit fa0639b

Browse files
author
Virginia Fernandez
committed
Fixes
1 parent 0aeb4d9 commit fa0639b

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

monai/losses/perceptual.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,17 @@
1515

1616
import torch
1717
import torch.nn as nn
18+
from huggingface_hub import hf_hub_download
1819

1920
from monai.utils import optional_import
2021
from monai.utils.enums import StrEnum
21-
from huggingface_hub import hf_hub_download
2222

2323
LPIPS, _ = optional_import("lpips", name="LPIPS")
2424
torchvision, _ = optional_import("torchvision")
2525

2626

2727
class PercetualNetworkType(StrEnum):
28-
"""Types of neural networks that are supported by perceptua loss.
29-
"""
28+
"""Types of neural networks that are supported by perceptua loss."""
3029

3130
alex = "alex"
3231
vgg = "vgg"
@@ -116,8 +115,7 @@ def __init__(
116115
# If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used.
117116
if spatial_dims == 3 and is_fake_3d is False:
118117
self.perceptual_function = MedicalNetPerceptualSimilarity(
119-
net=network_type, verbose=False, channel_wise=channel_wise,
120-
cache_dir=cache_dir
118+
net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
121119
)
122120
elif "radimagenet_" in network_type:
123121
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
@@ -214,12 +212,17 @@ class MedicalNetPerceptualSimilarity(nn.Module):
214212
"""
215213

216214
def __init__(
217-
self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False,
215+
self,
216+
net: str = "medicalnet_resnet_10_23datasets",
217+
verbose: bool = False,
218+
channel_wise: bool = False,
218219
cache_dir: str | None = None,
219220
) -> None:
220221
super().__init__()
221222
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
222-
self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir)
223+
self.model = torch.hub.load(
224+
"Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir
225+
)
223226
self.eval()
224227

225228
self.channel_wise = channel_wise
@@ -305,12 +308,9 @@ class RadImageNetPerceptualSimilarity(nn.Module):
305308
verbose: if false, mute messages from torch Hub load function.
306309
"""
307310

308-
def __init__(self, net: str = "radimagenet_resnet50",
309-
verbose: bool = False,
310-
cache_dir: str | None = None) -> None:
311+
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None:
311312
super().__init__()
312-
self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose,
313-
cache_dir=cache_dir)
313+
self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir)
314314
self.eval()
315315

316316
for param in self.parameters():

0 commit comments

Comments
 (0)