|
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | import torch.nn as nn |
| 18 | +from huggingface_hub import hf_hub_download |
18 | 19 |
|
19 | 20 | from monai.utils import optional_import |
20 | 21 | from monai.utils.enums import StrEnum |
21 | | -from huggingface_hub import hf_hub_download |
22 | 22 |
|
23 | 23 | LPIPS, _ = optional_import("lpips", name="LPIPS") |
24 | 24 | torchvision, _ = optional_import("torchvision") |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class PercetualNetworkType(StrEnum): |
28 | | - """Types of neural networks that are supported by perceptua loss. |
29 | | - """ |
| 28 | + """Types of neural networks that are supported by perceptua loss.""" |
30 | 29 |
|
31 | 30 | alex = "alex" |
32 | 31 | vgg = "vgg" |
@@ -116,8 +115,7 @@ def __init__( |
116 | 115 | # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. |
117 | 116 | if spatial_dims == 3 and is_fake_3d is False: |
118 | 117 | self.perceptual_function = MedicalNetPerceptualSimilarity( |
119 | | - net=network_type, verbose=False, channel_wise=channel_wise, |
120 | | - cache_dir=cache_dir |
| 118 | + net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir |
121 | 119 | ) |
122 | 120 | elif "radimagenet_" in network_type: |
123 | 121 | self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) |
@@ -214,12 +212,17 @@ class MedicalNetPerceptualSimilarity(nn.Module): |
214 | 212 | """ |
215 | 213 |
|
216 | 214 | def __init__( |
217 | | - self, net: str = "medicalnet_resnet_10_23datasets", verbose: bool = False, channel_wise: bool = False, |
| 215 | + self, |
| 216 | + net: str = "medicalnet_resnet_10_23datasets", |
| 217 | + verbose: bool = False, |
| 218 | + channel_wise: bool = False, |
218 | 219 | cache_dir: str | None = None, |
219 | 220 | ) -> None: |
220 | 221 | super().__init__() |
221 | 222 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True |
222 | | - self.model = torch.hub.load("Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir) |
| 223 | + self.model = torch.hub.load( |
| 224 | + "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir |
| 225 | + ) |
223 | 226 | self.eval() |
224 | 227 |
|
225 | 228 | self.channel_wise = channel_wise |
@@ -305,12 +308,9 @@ class RadImageNetPerceptualSimilarity(nn.Module): |
305 | 308 | verbose: if false, mute messages from torch Hub load function. |
306 | 309 | """ |
307 | 310 |
|
308 | | - def __init__(self, net: str = "radimagenet_resnet50", |
309 | | - verbose: bool = False, |
310 | | - cache_dir: str | None = None) -> None: |
| 311 | + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: |
311 | 312 | super().__init__() |
312 | | - self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, |
313 | | - cache_dir=cache_dir) |
| 313 | + self.model = torch.hub.load("Project-MONAI/perceptual-models", model=net, verbose=verbose, cache_dir=cache_dir) |
314 | 314 | self.eval() |
315 | 315 |
|
316 | 316 | for param in self.parameters(): |
|
0 commit comments