Skip to content

Commit 0baecb5

Browse files
authored
Add mobile_sam with controlnet_aux (#3000)
* Add mobile_sam with controlnet_aux for CNXL_Union
1 parent 3ff69b9 commit 0baecb5

File tree

4 files changed

+77
-1
lines changed

4 files changed

+77
-1
lines changed

annotator/mobile_sam/__init__.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import print_function
2+
3+
import os
4+
import numpy as np
5+
from PIL import Image
6+
from typing import Union
7+
8+
from modules import devices
9+
from annotator.util import load_model
10+
from annotator.annotator_path import models_path
11+
12+
from controlnet_aux import SamDetector
13+
from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
14+
15+
class SamDetector_Aux(SamDetector):
    """SamDetector variant that runs MobileSAM (vit_t) on the webui-selected
    device (``modules.devices.device``) instead of controlnet_aux's default
    device handling, and returns results as uint8 numpy arrays."""

    # Local cache directory for the downloaded MobileSAM checkpoint.
    model_dir = os.path.join(models_path, "mobile_sam")

    def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam):
        """
        Args:
            mask_generator: automatic mask generator wrapping ``sam``.
            sam: the SAM model instance; moved to the active device and
                put in eval mode here.
        """
        super().__init__(mask_generator)
        self.device = devices.device
        self.model = sam.to(self.device).eval()

    @classmethod
    def from_pretrained(cls):
        """Download (if needed) the MobileSAM checkpoint and build a detector.

        Possible model_type : vit_h, vit_l, vit_b, vit_t
        download weights from https://huggingface.co/dhkim2810/MobileSAM

        Returns:
            SamDetector_Aux: a ready-to-use detector instance.
        """
        remote_url = os.environ.get(
            "CONTROLNET_MOBILE_SAM_MODEL_URL",
            "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt",
        )
        model_path = load_model(
            "mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir
        )

        sam = sam_model_registry["vit_t"](checkpoint=model_path)

        # Fix: the original assigned `cls.model = sam.to(...).eval()`, mutating
        # shared CLASS state from an alternate constructor (and duplicating the
        # device move that __init__ performs). Keep the prepared model local;
        # __init__ owns the instance-level `self.model`.
        sam = sam.to(devices.device).eval()

        mask_generator = SamAutomaticMaskGenerator(sam)

        return cls(mask_generator, sam)

    def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray:
        """Segment ``input_image`` and return the annotation as a uint8 array.

        Args:
            input_image: RGB image as a numpy array or PIL Image.
            detect_resolution: resolution the detector runs at.
            image_resolution: resolution of the returned annotation.
            output_type: passed through to the base detector ("cv2" default).

        Returns:
            np.ndarray: segmentation visualization as uint8.
        """
        # The model may have been moved off-device between calls
        # (e.g. by webui low-vram offloading) — ensure it is back.
        self.model.to(self.device)
        image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
        return np.array(image).astype(np.uint8)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ matplotlib
99
facexlib
1010
timm<=0.9.5
1111
pydantic<=1.10.17
12+
controlnet_aux

scripts/preprocessor/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from .ip_adapter_auto import *
66
from .normal_dsine import *
77
from .model_free_preprocessors import *
8-
from .legacy.legacy_preprocessors import *
8+
from .legacy.legacy_preprocessors import *
9+
from .mobile_sam import *

scripts/preprocessor/mobile_sam.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from annotator.mobile_sam import SamDetector_Aux
2+
from scripts.supported_preprocessor import Preprocessor
3+
4+
class PreprocessorMobileSam(Preprocessor):
    """ControlNet "Segmentation" preprocessor backed by MobileSAM.

    The underlying detector is constructed lazily on first invocation so that
    importing this module never triggers a checkpoint download.
    """

    def __init__(self):
        super().__init__(name="mobile_sam")
        self.tags = ["Segmentation"]
        # Populated on first __call__ by SamDetector_Aux.from_pretrained().
        self.model = None

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Segment ``input_image`` at ``resolution``; returns a cv2-style array.

        The slider arguments are accepted for interface compatibility but
        are not used by this preprocessor.
        """
        if self.model is None:
            self.model = SamDetector_Aux.from_pretrained()

        return self.model(
            input_image,
            detect_resolution=resolution,
            image_resolution=resolution,
            output_type="cv2",
        )


Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())

0 commit comments

Comments
 (0)