1
+ from __future__ import print_function
2
+
3
+ import os
4
+ import numpy as np
5
+ from PIL import Image
6
+ from typing import Union
7
+
8
+ from modules import devices
9
+ from annotator .util import load_model
10
+ from annotator .annotator_path import models_path
11
+
12
+ from controlnet_aux import SamDetector
13
+ from controlnet_aux .segment_anything import sam_model_registry , SamAutomaticMaskGenerator
14
+
15
+ class SamDetector_Aux (SamDetector ):
16
+
17
+ model_dir = os .path .join (models_path , "mobile_sam" )
18
+
19
+ def __init__ (self , mask_generator : SamAutomaticMaskGenerator , sam ):
20
+ super ().__init__ (mask_generator )
21
+ self .device = devices .device
22
+ self .model = sam .to (self .device ).eval ()
23
+
24
+ @classmethod
25
+ def from_pretrained (cls ):
26
+ """
27
+ Possible model_type : vit_h, vit_l, vit_b, vit_t
28
+ download weights from https://huggingface.co/dhkim2810/MobileSAM
29
+ """
30
+ remote_url = os .environ .get (
31
+ "CONTROLNET_MOBILE_SAM_MODEL_URL" ,
32
+ "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt" ,
33
+ )
34
+ model_path = load_model (
35
+ "mobile_sam.pt" , remote_url = remote_url , model_dir = cls .model_dir
36
+ )
37
+
38
+ sam = sam_model_registry ["vit_t" ](checkpoint = model_path )
39
+
40
+ cls .model = sam .to (devices .device ).eval ()
41
+
42
+ mask_generator = SamAutomaticMaskGenerator (cls .model )
43
+
44
+ return cls (mask_generator , sam )
45
+
46
+ def __call__ (self , input_image : Union [np .ndarray , Image .Image ]= None , detect_resolution = 512 , image_resolution = 512 , output_type = "cv2" , ** kwargs ) -> np .ndarray :
47
+ self .model .to (self .device )
48
+ image = super ().__call__ (input_image = input_image , detect_resolution = detect_resolution , image_resolution = image_resolution , output_type = output_type , ** kwargs )
49
+ return np .array (image ).astype (np .uint8 )
0 commit comments