-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSegmenter.py
48 lines (34 loc) · 1.24 KB
/
Segmenter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import cv2
import matplotlib.image as mpimg
from PIL import Image
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from image_classifier import image_classifier
from tqdm import tqdm
# --- Segment Anything (SAM) configuration -------------------------------
# Registry key of the SAM backbone, the checkpoint file holding its weights,
# and the torch device the model is moved to.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"
device = "cpu"

# Station identifier. Not used by the executable code shown here.
ID_colonnina = 666

# Build the model once at import time so every Segmenter() call reuses it.
# torch.nn.Module.to returns the module itself, so the call can be chained.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
def _load_rgb_image(image_path):
    """Load ``image_path`` as an ``(H, W, 3)`` ``uint8`` RGB array.

    ``SamAutomaticMaskGenerator.generate`` expects an HWC uint8 image.
    ``matplotlib.image.imread`` does not guarantee that: PNG files come back
    as float32 in [0, 1] and may have an alpha channel (or no channel axis at
    all for grayscale). Forcing RGB through Pillow gives SAM a consistent
    input for every format.
    """
    with Image.open(image_path) as img:
        # np.array (not asarray) so the result is a writable, owned copy.
        return np.array(img.convert("RGB"))


def Segmenter(image_path):
    """Segment an image with SAM and classify every segmented region.

    Each mask produced by the module-level ``mask_generator`` is cropped to
    its bounding box, and the crop is passed to ``image_classifier``.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A list with one ``image_classifier`` result per mask whose bounding
        box crop is non-empty, in the order the masks were generated.
    """
    image = _load_rgb_image(image_path)
    masks = mask_generator.generate(image)
    list_of_labels = []
    for mask in tqdm(masks, desc="Segmenting over masks"):
        # The bounding box is unpacked as (x, y, width, height).
        x, y, w, h = (int(v) for v in mask["bbox"])
        cropped_image = image[y:y + h, x:x + w]
        # A degenerate (zero-width or zero-height) box yields an empty crop,
        # which a classifier cannot meaningfully process; skip it.
        if cropped_image.size == 0:
            continue
        # TODO (translated from the original Italian note): classify the
        # segmented region itself instead of its bounding-box crop.
        # NOTE(review): crops are now uint8 RGB for every input format
        # (previously float32 for PNGs) — confirm image_classifier expects this.
        label = image_classifier(cropped_image)
        list_of_labels.append(label)
    # TODO: wrap the result in a record such as
    # {"stationID": ID_colonnina, "time": <timestamp>, "object": list_of_labels}.
    return list_of_labels