diff --git a/README.md b/README.md
index e0a630a..ed3a17d 100644
--- a/README.md
+++ b/README.md
@@ -156,6 +156,20 @@ python scripts/export_onnx_model.py --checkpoint ./weights/mobile_sam.pt --model
 
 Also check the [example notebook](https://github.com/ChaoningZhang/MobileSAM/blob/master/notebooks/onnx_model_example.ipynb) to follow detailed steps. We recommend to use `onnx==1.12.0` and `onnxruntime==1.13.1` which is tested.
 
+## Pytorch Mobile Export
+
+**MobileSAM** can now be run on Pytorch Mobile. Export the model with:
+
+```
+python ./scripts/convert_pytorch_mobile.py output_dir
+```
+
+The result can be loaded as described in https://pytorch.org/tutorials/prototype/ios_gpu_workflow.html
+
+Note that the current version only runs on the CPU backend of Pytorch Mobile; the Metal backend appears to lack support for strided convolutions.
+
+The caller still needs to provide input scaling and normalization, as is done in
+[set_image()](https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/predictor.py) in the predictor example.
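+
+For reference, a minimal preprocessing sketch following what `set_image()` does (resize the
+longest side to 1024, normalize with the `pixel_mean`/`pixel_std` defaults of `Sam`, then
+zero-pad to 1024x1024). It assumes `image` is an RGB `uint8` array of shape HxWx3:
+
+```python
+import torch
+from torch.nn import functional as F
+from mobile_sam.utils.transforms import ResizeLongestSide
+
+transform = ResizeLongestSide(1024)
+image = transform.apply_image(image)                        # resized HxWx3 uint8
+x = torch.as_tensor(image).permute(2, 0, 1)[None].float()   # 1x3xHxW
+mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
+std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
+x = (x - mean) / std
+h, w = x.shape[-2:]
+x = F.pad(x, (0, 1024 - w, 0, 1024 - h))                    # zero-pad to 1024x1024
+```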
 
 ## BibTex of our MobileSAM
 If you use MobileSAM in your research, please use the following BibTeX entry. :mega: Thank you!
diff --git a/mobile_sam/modeling/mask_decoder.py b/mobile_sam/modeling/mask_decoder.py
index 5d2fdb0..0d99c24 100644
--- a/mobile_sam/modeling/mask_decoder.py
+++ b/mobile_sam/modeling/mask_decoder.py
@@ -99,10 +99,8 @@ def forward(
         )
 
         # Select the correct mask or masks for output
-        if multimask_output:
-            mask_slice = slice(1, None)
-        else:
-            mask_slice = slice(0, 1)
+        mask_slice = slice(1 if multimask_output else 0, None if multimask_output else 1)
+
         masks = masks[:, mask_slice, :, :]
         iou_pred = iou_pred[:, mask_slice]
 
@@ -137,8 +135,8 @@ def predict_masks(
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
         hyper_in_list: List[torch.Tensor] = []
-        for i in range(self.num_mask_tokens):
-            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+        for i, output_hypernetwork_mlp in enumerate(self.output_hypernetworks_mlps):  # was: range(self.num_mask_tokens)
+            hyper_in_list.append(output_hypernetwork_mlp(mask_tokens_out[:, i, :]))
         hyper_in = torch.stack(hyper_in_list, dim=1)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
diff --git a/mobile_sam/modeling/prompt_encoder.py b/mobile_sam/modeling/prompt_encoder.py
index c3143f4..b1d76b0 100644
--- a/mobile_sam/modeling/prompt_encoder.py
+++ b/mobile_sam/modeling/prompt_encoder.py
@@ -194,7 +194,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
         """Generate positional encoding for a grid of the specified size."""
         h, w = size
-        device: Any = self.positional_encoding_gaussian_matrix.device
+        device: torch.device = self.positional_encoding_gaussian_matrix.device
         grid = torch.ones((h, w), device=device, dtype=torch.float32)
         y_embed = grid.cumsum(dim=0) - 0.5
         x_embed = grid.cumsum(dim=1) - 0.5
diff --git a/mobile_sam/modeling/sam.py b/mobile_sam/modeling/sam.py
index 45b9e7c..c8ccc35 100644
--- a/mobile_sam/modeling/sam.py
+++ b/mobile_sam/modeling/sam.py
@@ -16,9 +16,10 @@
 from .prompt_encoder import PromptEncoder
 
+MASK_THRESHOLD_DEFAULT: float = 0.0
+IMAGE_FORMAT_DEFAULT: str = "RGB"
+
 
 class Sam(nn.Module):
-    mask_threshold: float = 0.0
-    image_format: str = "RGB"
 
     def __init__(
         self,
@@ -27,6 +28,8 @@ def __init__(
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = [123.675, 116.28, 103.53],
         pixel_std: List[float] = [58.395, 57.12, 57.375],
+        mask_threshold: float = MASK_THRESHOLD_DEFAULT,
+        image_format: str = IMAGE_FORMAT_DEFAULT,
     ) -> None:
         """
         SAM predicts object masks from an image and input prompts.
@@ -46,15 +49,16 @@ def __init__(
         self.mask_decoder = mask_decoder
         self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
         self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+        self.mask_threshold = mask_threshold
+        self.image_format = image_format
 
     @property
     def device(self) -> Any:
         return self.pixel_mean.device
 
-    @torch.no_grad()
     def forward(
         self,
-        batched_input: List[Dict[str, Any]],
+        batched_input: List[Dict[str, Union[torch.Tensor, Tuple[int, int]]]],
         multimask_output: bool,
     ) -> List[Dict[str, torch.Tensor]]:
         """
@@ -95,47 +99,68 @@
             shape BxCxHxW, where H=W=256. Can be passed as mask input
             to subsequent iterations of prediction.
         """
-        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
-        image_embeddings = self.image_encoder(input_images)
-
-        outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
-            if "point_coords" in image_record:
-                points = (image_record["point_coords"], image_record["point_labels"])
-            else:
-                points = None
-            sparse_embeddings, dense_embeddings = self.prompt_encoder(
-                points=points,
-                boxes=image_record.get("boxes", None),
-                masks=image_record.get("mask_inputs", None),
-            )
-            low_res_masks, iou_predictions = self.mask_decoder(
-                image_embeddings=curr_embedding.unsqueeze(0),
-                image_pe=self.prompt_encoder.get_dense_pe(),
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
-                multimask_output=multimask_output,
-            )
-            masks = self.postprocess_masks(
-                low_res_masks,
-                input_size=image_record["image"].shape[-2:],
-                original_size=image_record["original_size"],
-            )
-            masks = masks > self.mask_threshold
-            outputs.append(
-                {
-                    "masks": masks,
-                    "iou_predictions": iou_predictions,
-                    "low_res_logits": low_res_masks,
-                }
-            )
-        return outputs
+        with torch.no_grad():
+            input_images_list = []
+            for x in batched_input:
+                img = x["image"]  # isinstance/annotate checks refine the Union type for Torchscript support
+                assert isinstance(img, torch.Tensor)
+                processed_image = self.preprocess(torch.jit.annotate(torch.Tensor, img))
+                input_images_list.append(processed_image)
+
+            input_images = torch.stack(input_images_list, dim=0)
+            image_embeddings = self.image_encoder(input_images)
+
+            outputs: List[Dict[str, torch.Tensor]] = []
+            for image_record, curr_embedding in zip(batched_input, image_embeddings):
+                boxes = image_record["boxes"] if "boxes" in image_record else None
+                assert isinstance(boxes, Optional[torch.Tensor])
+                boxes = torch.jit.annotate(Optional[torch.Tensor], boxes)
+                masks = image_record["mask_inputs"] if "mask_inputs" in image_record else None
+                assert isinstance(masks, Optional[torch.Tensor])
+                if "point_coords" in image_record:
+                    pc = image_record["point_coords"]
+                    assert isinstance(pc, torch.Tensor)
+                    pl = image_record["point_labels"]
+                    assert isinstance(pl, torch.Tensor)
+                    points = (pc, pl)
+                else:
+                    points = None
+                sparse_embeddings, dense_embeddings = self.prompt_encoder(
+                    points=points,
+                    boxes=boxes,
+                    masks=masks,
+                )
+                low_res_masks, iou_predictions = self.mask_decoder(
+                    image_embeddings=curr_embedding.unsqueeze(0),
+                    image_pe=self.prompt_encoder.get_dense_pe(),
+                    sparse_prompt_embeddings=sparse_embeddings,
+                    dense_prompt_embeddings=dense_embeddings,
+                    multimask_output=multimask_output,
+                )
+                orig_size = image_record["original_size"]
+                assert isinstance(orig_size, Tuple[int, int])
+                img = image_record["image"]
+                assert isinstance(img, torch.Tensor)
+                masks = self.postprocess_masks(
+                    low_res_masks,
+                    input_size=img.shape[-2:],
+                    original_size=orig_size,
+                )
+                masks = masks > self.mask_threshold
+                outputs.append(
+                    {
+                        "masks": masks,
+                        "iou_predictions": iou_predictions,
+                        "low_res_logits": low_res_masks,
+                    }
+                )
+            return outputs
 
     def postprocess_masks(
         self,
         masks: torch.Tensor,
-        input_size: Tuple[int, ...],
-        original_size: Tuple[int, ...],
+        input_size: List[int],
+        original_size: Tuple[int, int],
     ) -> torch.Tensor:
         """
         Remove padding and upscale masks to the original image size.
diff --git a/mobile_sam/modeling/tiny_vit_sam.py b/mobile_sam/modeling/tiny_vit_sam.py
index 93fa9f5..e370c75 100644
--- a/mobile_sam/modeling/tiny_vit_sam.py
+++ b/mobile_sam/modeling/tiny_vit_sam.py
@@ -79,8 +79,11 @@ def __init__(self, in_chans, out_chans, expand_ratio,
                  activation, drop_path):
         super().__init__()
         self.in_chans = in_chans
+        assert self.in_chans > 0
         self.hidden_chans = int(in_chans * expand_ratio)
+        assert self.hidden_chans > 0
         self.out_chans = out_chans
+        assert self.out_chans > 0
 
         self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
         self.act1 = activation()
@@ -177,7 +180,7 @@ def __init__(self, dim, input_resolution, depth,
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.use_checkpoint:
+            if self.use_checkpoint and not torch.jit.is_scripting():
                 x = checkpoint.checkpoint(blk, x)
             else:
                 x = blk(x)
@@ -335,7 +338,9 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7,
     def forward(self, x):
         H, W = self.input_resolution
        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        assert L == H * W, f"input feature has wrong size: {L} != {H} * {W}"
+        assert H > 0, "height is 0"
+        assert W > 0, "width is 0"
         res_x = x
         if H == self.window_size and W == self.window_size:
             x = self.attn(x)
@@ -346,9 +351,17 @@
             pad_r = (self.window_size - W % self.window_size) % self.window_size
             padding = pad_b > 0 or pad_r > 0
 
-            if padding:
-                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+            x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+
+            # Alternative to the above (Pytorch Lite's Metal backend does not provide F.pad); torch.zeros matches F.pad's zero padding:
+            # if pad_b > 0:
+            #     pad_tensor_b = torch.zeros(size=(B, pad_b, W, C), dtype=x.dtype, device=x.device)
+            #     x = torch.cat([x, pad_tensor_b], dim=1)  # concatenate to the bottom of the height dimension
+            #
+            # if pad_r > 0:
+            #     pad_tensor_r = torch.zeros(size=(B, H + pad_b, pad_r, C), dtype=x.dtype, device=x.device)
+            #     x = torch.cat([x, pad_tensor_r], dim=2)
 
             pH, pW = H + pad_b, W + pad_r
             nH = pH // self.window_size
@@ -435,7 +448,7 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
 
     def forward(self, x):
         for blk in self.blocks:
-            if self.use_checkpoint:
+            if self.use_checkpoint and not torch.jit.is_scripting():
                 x = checkpoint.checkpoint(blk, x)
             else:
                 x = blk(x)
@@ -446,6 +459,7 @@ def forward(self, x):
     def extra_repr(self) -> str:
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
+
 class LayerNorm2d(nn.Module):
     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
         super().__init__()
@@ -459,6 +473,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = (x - u) / torch.sqrt(s + self.eps)
         x = self.weight[:, None, None] * x + self.bias[:, None, None]
         return x
+
+
 class TinyViT(nn.Module):
     def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                  embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
@@ -496,24 +512,18 @@ def __init__(self, img_size=224, in_chans=3, num_classes=1000,
         # build layers
         self.layers = nn.ModuleList()
         for i_layer in range(self.num_layers):
-            kwargs = dict(dim=embed_dims[i_layer],
-                          input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
-                                            patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
-                          #    input_resolution=(patches_resolution[0] // (2 ** i_layer),
-                          #                      patches_resolution[1] // (2 ** i_layer)),
-                          depth=depths[i_layer],
-                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
-                          downsample=PatchMerging if (
-                              i_layer < self.num_layers - 1) else None,
-                          use_checkpoint=use_checkpoint,
-                          out_dim=embed_dims[min(
-                              i_layer + 1, len(embed_dims) - 1)],
-                          activation=activation,
-                          )
             if i_layer == 0:
                 layer = ConvLayer(
                     conv_expand_ratio=mbconv_expand_ratio,
-                    **kwargs,
+                    dim=embed_dims[i_layer],
+                    input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                                      patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
+                    depth=depths[i_layer],
+                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                    downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                    use_checkpoint=use_checkpoint,
+                    out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
+                    activation=activation,
                 )
             else:
                 layer = BasicLayer(
@@ -522,7 +532,15 @@ def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                     mlp_ratio=self.mlp_ratio,
                     drop=drop_rate,
                     local_conv_size=local_conv_size,
-                    **kwargs)
+                    dim=embed_dims[i_layer],
+                    input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                                      patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
+                    depth=depths[i_layer],
+                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                    downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint,
+                    out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
+                    activation=activation,
+                )
             self.layers.append(layer)
 
         # Classifier head
@@ -600,13 +618,11 @@ def no_weight_decay_keywords(self):
     def forward_features(self, x):
         # x: (N, C, H, W)
         x = self.patch_embed(x)
-
         x = self.layers[0](x)
         start_i = 1
 
-        for i in range(start_i, len(self.layers)):
-            layer = self.layers[i]
-            x = layer(x)
+        for i, layer in enumerate(self.layers[1:]):  # was: range(start_i, len(self.layers))
+            x = layer.forward(x)
         B,_,C=x.size()
         x = x.view(B, 64, 64, C)
         x=x.permute(0, 3, 1, 2)
diff --git a/mobile_sam/utils/onnx.py b/mobile_sam/utils/onnx.py
index 3196bdf..cf2f5fc 100644
--- a/mobile_sam/utils/onnx.py
+++ b/mobile_sam/utils/onnx.py
@@ -8,12 +8,11 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
-from typing import Tuple
+from typing import Tuple, Union
 
 from ..modeling import Sam
 from .amg import calculate_stability_score
 
-
 class SamOnnxModel(nn.Module):
     """
     This model should not be called directly, but is used in ONNX export.
@@ -59,10 +58,8 @@ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor)
             point_labels == -1
         )
 
-        for i in range(self.model.prompt_encoder.num_point_embeddings):
-            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
-                i
-            ].weight * (point_labels == i)
+        for i, embedding in enumerate(self.model.prompt_encoder.point_embeddings):
+            point_embedding = point_embedding + embedding.weight * (point_labels == i)
 
         return point_embedding
 
@@ -85,8 +82,8 @@ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -
         masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
 
         orig_im_size = orig_im_size.to(torch.int64)
-        h, w = orig_im_size[0], orig_im_size[1]
-        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+        h, w = int(orig_im_size[0]), int(orig_im_size[1])
+        masks = F.interpolate(masks, size=[h, w], mode="bilinear", align_corners=False)
         return masks
 
     def select_masks(
@@ -104,7 +101,6 @@ def select_masks(
 
         return masks, iou_preds
 
-    @torch.no_grad()
     def forward(
         self,
         image_embeddings: torch.Tensor,
@@ -113,7 +109,8 @@ def forward(
         mask_input: torch.Tensor,
         has_mask_input: torch.Tensor,
         orig_im_size: torch.Tensor,
-    ):
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+               Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
         sparse_embedding = self._embed_points(point_coords, point_labels)
         dense_embedding = self._embed_masks(mask_input, has_mask_input)
 
@@ -135,7 +132,7 @@ def forward(
         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
 
         if self.return_extra_metrics:
-            stability_scores = calculate_stability_score(
+            stability_scores = calculate_stability_score(
                 upscaled_masks, self.model.mask_threshold, self.stability_score_offset
             )
             areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
diff --git a/scripts/convert_pytorch_mobile.py b/scripts/convert_pytorch_mobile.py
new file mode 100644
index 0000000..aa24615
--- /dev/null
+++ b/scripts/convert_pytorch_mobile.py
@@ -0,0 +1,78 @@
+import argparse
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+from torch import nn
+from torch.utils.mobile_optimizer import optimize_for_mobile
+from mobile_sam import sam_model_registry, SamPredictor
+from mobile_sam.utils.onnx import SamOnnxModel
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(prog='Convert', description='Convert MobileSAM to TorchScript for Pytorch Mobile')
+    parser.add_argument("--model_type", default="vit_t", help="registered model type")
+    parser.add_argument("--checkpoint", default="./weights/mobile_sam.pt", help="model file")
+    parser.add_argument('output', help="Output directory.")
+    args = parser.parse_args()
+
+    os.makedirs(args.output, exist_ok=True)
+
+    print("Loading model...")
+    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
+    model = SamOnnxModel(sam, return_single_mask=True)
+
+    embed_dim = sam.prompt_encoder.embed_dim
+    embed_size = sam.prompt_encoder.image_embedding_size
+    enc = sam.image_encoder.eval()
+    ex = torch.randn(1, 3, 1024, 1024, dtype=torch.float32)
+    mask_input_size = [4 * x for x in embed_size]
+    out = enc(ex)  # eager reference output of the image encoder
+
+    # def replace_gelu_with_tanh(model):
+    #     for child_name, child_module in model.named_children():
+    #         if isinstance(child_module, nn.GELU):
+    #             print("replacing gelu with tanh")
+    #             setattr(model, child_name, nn.Tanh())
+    #         else:
+    #             replace_gelu_with_tanh(child_module)
+    # replace_gelu_with_tanh(enc)
+
+    embedding_model_ts = torch.jit.script(
+        enc,
+        example_inputs=[(ex,)],  # a list of argument tuples; this replaces the puzzling ex.unsqueeze(0) workaround
+    )
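+
+    # Optional sanity check (a sketch, not required for the export): the scripted encoder
+    # should reproduce the eager output `out` computed above for the same example input.
+    with torch.no_grad():
+        out_scripted = embedding_model_ts(ex)
+    assert torch.allclose(out, out_scripted, atol=1e-5), "scripted encoder output differs from eager output"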
+
+    decoder_inputs = {
+        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
+        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
+        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
+        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
+        "has_mask_input": torch.tensor([1], dtype=torch.float),
+        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
+    }
+
+    predictor_model_ts = torch.jit.script(
+        model,
+        example_inputs=[
+            tuple(decoder_inputs.values())  # one argument tuple, in the order of SamOnnxModel.forward
+        ],
+    )
+
+    def save_pt(model, model_filename: str):
+        print("Optimizing for Pytorch Mobile")
+        torch.jit.save(model, os.path.join(args.output, f"{model_filename}.pt"))
+        print("before optimize", torch.jit.export_opnames(model))
+        # torch.quantization.fuse_models...
+        model_cpu = optimize_for_mobile(model, backend="cpu")
+        print("after optimize for cpu: ", torch.jit.export_opnames(model_cpu))
+        model_cpu._save_for_lite_interpreter(os.path.join(args.output, f"cpu_{model_filename}.ptl"))
+        model_metal = optimize_for_mobile(model, backend="metal")
+        print("after optimize for metal: ", torch.jit.export_opnames(model_metal))
+        print(model_metal.code)
+        model_metal._save_for_lite_interpreter(os.path.join(args.output, f"metal_{model_filename}.ptl"))
+
+    save_pt(embedding_model_ts, "vit_image_embedding")
+    save_pt(predictor_model_ts, "mobilesam_predictor")
+
+    print("Done")
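+
+    # Optional post-export check (a sketch, not part of the export itself): reload the saved
+    # TorchScript encoder with torch.jit.load and run the example input once more to make
+    # sure the written file is usable.
+    reloaded_enc = torch.jit.load(os.path.join(args.output, "vit_image_embedding.pt"))
+    with torch.no_grad():
+        reloaded_out = reloaded_enc(ex)
+    print("reloaded embedding output shape:", tuple(reloaded_out.shape))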