diff --git a/README.md b/README.md index 95d24ee2..a695dbd5 100644 --- a/README.md +++ b/README.md @@ -394,6 +394,19 @@ Note: In the official github repo the s0 variant has additional num_conv_branche +
+SAM +
+ +| Encoder | Weights | Params, M | +|-----------|:--------:|:---------:| +| sam-vit_b | sa-1b | 91M | +| sam-vit_l | sa-1b | 308M | +| sam-vit_h | sa-1b | 636M | + +
+
+ \* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)). diff --git a/docs/encoders.rst b/docs/encoders.rst index d64607b8..55946e6e 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -361,3 +361,16 @@ MobileOne +-----------------+----------+------------+ | mobileone\_s4 | imagenet | 13.6M | +-----------------+----------+------------+ + +SAM +~~~~~~~~~~~~~~~~~~~~~ + ++-----------------+----------+------------+ +| Encoder | Weights | Params, M | ++=================+==========+============+ +| sam-vit_b | sa-1b | 91M | ++-----------------+----------+------------+ +| sam-vit_l | sa-1b | 308M | ++-----------------+----------+------------+ +| sam-vit_h | sa-1b | 636M | ++-----------------+----------+------------+ diff --git a/docs/models.rst b/docs/models.rst index 47de61ee..a5ab52c1 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -36,5 +36,3 @@ DeepLabV3 DeepLabV3+ ~~~~~~~~~~ .. 
def get_pretrained_settings(encoders: dict, encoder_name: str, weights: str) -> dict:
    """Get pretrained settings for encoder from encoders collection.

    Args:
        encoders: collection of encoders, mapping an encoder name to a dict
            that holds a ``"pretrained_settings"`` mapping
        encoder_name: name of encoder in collection
        weights: name of the pretrained settings to fetch (e.g. ``imagenet``)

    Returns:
        pretrained settings for encoder

    Raises:
        KeyError: in case of wrong encoder name or pretrained settings name
    """
    # Check the encoder name on its own: building the "wrong weights" message
    # below indexes ``encoders[encoder_name]`` and would otherwise fail with a
    # bare, unexplained KeyError for an unknown encoder.
    if encoder_name not in encoders:
        raise KeyError(
            "Wrong encoder name `{}`, supported encoders: {}".format(
                encoder_name,
                list(encoders.keys()),
            )
        )

    available = encoders[encoder_name]["pretrained_settings"]
    try:
        return available[weights]
    except KeyError:
        # ``from None`` hides the uninformative inner lookup error.
        raise KeyError(
            "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                weights,
                encoder_name,
                list(available.keys()),
            )
        ) from None
import math
import warnings
from typing import Mapping, Any

import torch
from segment_anything.modeling import ImageEncoderViT
from torch import nn
from segment_anything.modeling.common import LayerNorm2d

from segmentation_models_pytorch.encoders._base import EncoderMixin


class SamVitEncoder(EncoderMixin, ImageEncoderViT):
    """SAM ViT image encoder exposing a pyramid of ``depth + 1`` feature maps.

    Transformer block outputs are tapped every ``vit_depth // depth`` blocks.
    Each tapped output is upsampled and passed through its own neck, so the
    returned features double in spatial size and halve in channels from the
    deepest map to the shallowest one.

    Keyword Args:
        vit_depth: number of transformer blocks (12, 24 or 32). Required.
        depth: number of intermediate feature maps (encoder depth), default 5.
            It must equal ``log2(patch_size)``.
        **kwargs: everything else is forwarded to ``ImageEncoderViT``.

    Raises:
        KeyError: if ``vit_depth`` is not supplied.
        ValueError: if ``vit_depth`` is unsupported or ``depth`` does not
            match ``patch_size``.
    """

    def __init__(self, **kwargs):
        self._vit_depth = kwargs.pop("vit_depth")
        self._encoder_depth = kwargs.get("depth", 5)
        # NOTE(review): these fallbacks are presumed to mirror the
        # ImageEncoderViT constructor defaults - confirm against the library.
        self._out_chans = kwargs.get("out_chans", 256)
        self._patch_size = kwargs.get("patch_size", 16)
        self._embed_dim = kwargs.get("embed_dim", 768)
        # Validate before building the transformer so a bad configuration is
        # reported as a ValueError instead of failing inside the parent class.
        self._validate()
        # The parent class uses ``depth`` for the number of transformer blocks.
        kwargs.update({"depth": self._vit_depth})
        super().__init__(**kwargs)
        self.intermediate_necks = nn.ModuleList(
            [self.init_neck(self._embed_dim, out_chan) for out_chan in self.out_channels[:-1]]
        )

    @staticmethod
    def init_neck(embed_dim: int, out_chans: int) -> nn.Module:
        """Build a 1x1 conv -> LayerNorm2d -> 3x3 conv -> LayerNorm2d neck."""
        # Use similar neck as in ImageEncoderViT
        return nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    @staticmethod
    def neck_forward(neck: nn.Module, x: torch.Tensor, scale_factor: float = 1) -> torch.Tensor:
        """Move channels to dim 1, optionally upsample bilinearly, apply ``neck``."""
        # The permutation moves the last dimension to the channel position
        # expected by the convolutions in the neck.
        x = x.permute(0, 3, 1, 2)
        if scale_factor != 1.0:
            x = nn.functional.interpolate(x, scale_factor=scale_factor, mode="bilinear")
        return neck(x)

    def requires_grad_(self, requires_grad: bool = True):
        """Set ``requires_grad`` on all parameters except the intermediate necks.

        The intermediate necks always stay trainable, whatever the argument.

        Returns:
            self
        """
        for param in self.parameters():
            param.requires_grad_(requires_grad)
        # Keep the intermediate necks trainable
        for param in self.intermediate_necks.parameters():
            param.requires_grad_(True)
        return self

    @property
    def output_stride(self):
        """Reported output stride (fixed value)."""
        return 32

    @property
    def out_channels(self):
        """Channels of each returned feature map, shallowest first."""
        return [self._out_chans // (2**i) for i in range(self._encoder_depth + 1)][::-1]

    def _validate(self):
        """Check ``vit_depth`` and that ``depth`` matches ``patch_size``.

        Raises:
            ValueError: on an unsupported ``vit_depth`` or when ``depth`` is
                not exactly ``log2(patch_size)``.
        """
        # check vit depth
        if self._vit_depth not in [12, 24, 32]:
            raise ValueError(f"vit_depth must be one of [12, 24, 32], got {self._vit_depth}")
        # check output: compare against the exact (untruncated) log2 so that a
        # patch size that is not a power of two is rejected as well.
        if math.log2(self._patch_size) != self._encoder_depth:
            raise ValueError(
                f"With patch_size={self._patch_size} and depth={self._encoder_depth}, "
                "spatial dimensions of model output will not match input spatial dimensions. "
                "It is recommended to set encoder depth=4 with default vit patch_size=16."
            )

    def _get_scale_factor(self) -> int:
        """Return ``log2(patch_size)``, the exponent of the patch downscaling."""
        return int(math.log2(self._patch_size))

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Return ``depth + 1`` feature maps, shallowest first.

        The first ``depth`` maps come from tapped transformer blocks via the
        intermediate necks; the last one is the regular neck output.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        features = []
        skip_steps = self._vit_depth // self._encoder_depth
        scale_factor = self._get_scale_factor()
        n_necks = len(self.intermediate_necks)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            neck_idx, offset = divmod(i, skip_steps)
            # When vit_depth is not a multiple of the encoder depth, the
            # stride selects more blocks than there are necks; ignore the
            # surplus instead of indexing past the end of the ModuleList.
            if offset == 0 and neck_idx < n_necks:
                neck = self.intermediate_necks[neck_idx]
                # Each successive tap is upsampled by half as much as the previous one
                features.append(self.neck_forward(neck, x, scale_factor=2**scale_factor))
                scale_factor -= 1

        x = self.neck(x.permute(0, 3, 1, 2))
        features.append(x)

        return features

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load image-encoder weights from a full SAM checkpoint.

        Keys starting with ``mask_decoder`` or ``prompt_encoder`` are dropped
        and a leading ``image_encoder.`` is stripped. Loading is always
        non-strict (``strict`` is accepted for signature compatibility only),
        because the intermediate necks have no pretrained weights. A warning
        is emitted for any other missing key or any unused key.

        Returns:
            The result of the parent ``load_state_dict`` call (missing and
            unexpected keys).
        """
        prefix = "image_encoder."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
            if not k.startswith("mask_decoder") and not k.startswith("prompt_encoder")
        }
        result = super().load_state_dict(state_dict, strict=False)
        missing, unused = result
        missing = [k for k in missing if not k.startswith("intermediate_necks")]
        if len(missing) + len(unused) > 0:
            # Missing keys were never part of ``state_dict``; only the unused
            # entries reduce the number of loaded ones.
            n_loaded = len(state_dict) - len(unused)
            warnings.warn(
                f"Only {n_loaded} out of pretrained {len(state_dict)} SAM image encoder modules are loaded. "
                f"Missing modules: {missing}. Unused modules: {unused}."
            )
        return result


# Registry entries for the SAM ViT encoders: constructor params and weight URLs.
sam_vit_encoders = {
    "sam-vit_h": {
        "encoder": SamVitEncoder,
        "pretrained_settings": {
            "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"},
        },
        "params": dict(
            embed_dim=1280,
            vit_depth=32,
            num_heads=16,
            global_attn_indexes=[7, 15, 23, 31],
        ),
    },
    "sam-vit_l": {
        "encoder": SamVitEncoder,
        "pretrained_settings": {
            "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"},
        },
        "params": dict(
            embed_dim=1024,
            vit_depth=24,
            num_heads=16,
            global_attn_indexes=[5, 11, 17, 23],
        ),
    },
    "sam-vit_b": {
        "encoder": SamVitEncoder,
        "pretrained_settings": {
            "sa-1b": {"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"},
        },
        "params": dict(
            embed_dim=768,
            vit_depth=12,
            num_heads=12,
            global_attn_indexes=[2, 5, 8, 11],
        ),
    },
}
import pytest
import torch

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders
from tests.test_models import get_sample, _test_forward, _test_forward_backward


@pytest.fixture(autouse=True)
def _restore_sam_encoder_params():
    """Restore the SAM registry params after every test in this module.

    ``get_encoder`` writes ``depth`` and the extra kwargs into the shared
    ``params`` dict of the registry entry, so without this the arguments of
    one test would leak into the next one.
    """
    saved = {name: dict(entry["params"]) for name, entry in sam_vit_encoders.items()}
    yield
    for name, params in saved.items():
        # Restore in place: the registry keeps referring to the same dict object.
        sam_vit_encoders[name]["params"].clear()
        sam_vit_encoders[name]["params"].update(params)


@pytest.mark.parametrize("encoder_name", ["sam-vit_b", "sam-vit_l"])
@pytest.mark.parametrize("img_size", [64, 128])
@pytest.mark.parametrize("patch_size,depth", [(8, 3), (16, 4)])
@pytest.mark.parametrize("vit_depth", [12, 24])
def test_sam_encoder(encoder_name, img_size, patch_size, depth, vit_depth):
    """Every returned feature map has the expected channels and spatial size."""
    encoder = get_encoder(encoder_name, img_size=img_size, patch_size=patch_size, depth=depth, vit_depth=vit_depth)
    assert encoder.output_stride == 32
    assert encoder.out_channels == [256 // (2**i) for i in range(depth + 1)][::-1]

    sample = torch.ones(1, 3, img_size, img_size)
    with torch.no_grad():
        out = encoder(sample)

    assert len(out) == depth + 1

    expected_spatial_size = img_size // patch_size
    expected_chans = 256
    # Walk from the deepest map to the shallowest one, including out[0].
    for feature in reversed(out):
        assert feature.size() == torch.Size([1, expected_chans, expected_spatial_size, expected_spatial_size])
        expected_spatial_size *= 2
        expected_chans //= 2


def test_sam_encoder_trainable():
    """Freezing the encoder leaves only the intermediate necks trainable."""
    encoder = get_encoder("sam-vit_b", depth=4)

    encoder.requires_grad_(False)
    for name, param in encoder.named_parameters():
        if name.startswith("intermediate_necks"):
            assert param.requires_grad
        else:
            assert not param.requires_grad

    encoder.requires_grad_(True)
    for param in encoder.parameters():
        assert param.requires_grad


@pytest.mark.parametrize(
    "depth,vit_depth,expected_error",
    [
        # depth does not match log2(patch_size)
        (5, 12, ValueError),
        # NOTE(review): a ``None`` vit_depth may already be rejected inside the
        # parent constructor with a TypeError before validation runs - confirm.
        (4, None, (TypeError, ValueError)),
        # unsupported vit_depth
        (4, 6, ValueError),
    ],
)
def test_sam_encoder_validation_error(depth, vit_depth, expected_error):
    """Each invalid configuration is rejected on its own.

    One case per ``pytest.raises`` block: statements placed after the first
    raising call inside a single block would never execute.
    """
    with pytest.raises(expected_error):
        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=depth, vit_depth=vit_depth)
@pytest.mark.parametrize("model_class", [smp.Unet])
@pytest.mark.parametrize("decoder_channels,encoder_depth", [([64, 32, 16, 8], 4)])
def test_sam_encoder_arch(model_class, decoder_channels, encoder_depth):
    """A model built on the SAM encoder runs forward and backward."""
    img_size = 1024
    model = model_class(
        "sam-vit_b",
        encoder_weights=None,
        encoder_depth=encoder_depth,
        decoder_channels=decoder_channels,
    )
    # Named ``sample`` so the imported ``smp`` module is not shadowed.
    sample = torch.ones(1, 3, img_size, img_size)
    _test_forward_backward(model, sample, test_shape=True)


@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
def test_sam_encoder_weights():
    """Build a model with the pretrained ``sa-1b`` encoder weights."""
    smp.create_model(
        "unet", encoder_name="sam-vit_b", encoder_depth=4, encoder_weights="sa-1b", decoder_channels=[64, 32, 16, 8]
    )