Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Jul 15, 2024
1 parent 3ff69b9 commit 1c6c5a4
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
48 changes: 48 additions & 0 deletions annotator/mobile_sam/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import print_function

import os
import numpy as np
from PIL import Image
from typing import Union

from modules import devices
from annotator.util import load_model
from annotator.annotator_path import models_path

from controlnet_aux import SamDetector
from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

class SamDetector_Aux(SamDetector):

model_dir = os.path.join(models_path, "mobile_sam")

def __init__(self, mask_generator: SamAutomaticMaskGenerator):
super().__init__(mask_generator)

self.device = devices.device
self.model = SamDetector_Aux().to(self.device).eval()
self.from_pretrained(model_type="vit_t")

@classmethod
def from_pretrained(cls, model_type="vit_t"):
"""
Possible model_type : vit_h, vit_l, vit_b, vit_t
download weights from https://huggingface.co/dhkim2810/MobileSAM
"""
remote_url = os.environ.get(
"CONTROLNET_MOBILE_SAM_MODEL_URL",
"https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt",
)
model_path = load_model(
"mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir
)

sam = sam_model_registry[model_type](checkpoint=model_path)

mask_generator = SamAutomaticMaskGenerator(sam)

return cls(mask_generator)

def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> np.ndarray:
self.model.to(self.device)
super().__call__(image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ matplotlib
facexlib
timm<=0.9.5
pydantic<=1.10.17
controlnet_aux
30 changes: 30 additions & 0 deletions scripts/preprocessor/mobile_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
from skimage import morphology

from annotator.mobile_sam import SamDetector_Aux
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
from scripts.utils import resize_image_with_pad

class PreprocessorMobileSam(Preprocessor):
def __init__(self):
super().__init__(name="mobile_sam")
self.tags = ["Segmentation"]
self.model = None

def __call__(
self,
input_image,
resolution,
slider_1=None,
slider_2=None,
slider_3=None,
**kwargs
):
img, remove_pad = resize_image_with_pad(input_image, resolution)
if self.model is None:
self.model = SamDetector_Aux()

result = self.model(img, detect_resolution=resolution, image_resolution=resolution)
return remove_pad(result)

Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())

0 comments on commit 1c6c5a4

Please sign in to comment.