diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index a409a662..842719b0 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -37,13 +37,23 @@ def get_encoder(name, in_channels=3, depth=5, weights=None): - Encoder = encoders[name]["encoder"] + + try: + Encoder = encoders[name]["encoder"] + except KeyError: + raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))) + params = encoders[name]["params"] params.update(depth=depth) encoder = Encoder(**params) if weights is not None: - settings = encoders[name]["pretrained_settings"][weights] + try: + settings = encoders[name]["pretrained_settings"][weights] + except KeyError: + raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Avaliable options are: {}".format( + weights, name, list(encoders[name]["pretrained_settings"].keys()), + )) encoder.load_state_dict(model_zoo.load_url(settings["url"])) encoder.set_in_channels(in_channels) diff --git a/segmentation_models_pytorch/pan/model.py b/segmentation_models_pytorch/pan/model.py index 74a52989..430657b1 100644 --- a/segmentation_models_pytorch/pan/model.py +++ b/segmentation_models_pytorch/pan/model.py @@ -44,7 +44,7 @@ class PAN(SegmentationModel): def __init__( self, encoder_name: str = "resnet34", - encoder_weights: str = "imagenet", + encoder_weights: Optional[str] = "imagenet", encoder_dilation: bool = True, decoder_channels: int = 32, in_channels: int = 3, diff --git a/segmentation_models_pytorch/unet/model.py b/segmentation_models_pytorch/unet/model.py index a45753b9..7009cf2d 100644 --- a/segmentation_models_pytorch/unet/model.py +++ b/segmentation_models_pytorch/unet/model.py @@ -51,7 +51,7 @@ def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, - encoder_weights: str = "imagenet", + encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, diff --git a/segmentation_models_pytorch/unetplusplus/model.py b/segmentation_models_pytorch/unetplusplus/model.py index 51a14a5b..3353a317 100644 --- a/segmentation_models_pytorch/unetplusplus/model.py +++ b/segmentation_models_pytorch/unetplusplus/model.py @@ -51,7 +51,7 @@ def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, - encoder_weights: str = "imagenet", + encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, decoder_channels: List[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None,