Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 81 additions & 36 deletions samm-python-terminal/utl_sam_server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
from utl_sam_msg import *
import numpy as np
from tqdm import tqdm
import sys,os, cv2, matplotlib.pyplot as plt
import os, cv2
import torch, functools, pickle
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"#

from segment_anything import sam_model_registry as sam_model_registry_sam
from segment_anything import SamPredictor as SamPredictor_sam
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from mobile_sam import SamPredictor as SamPredictor_mobile

from utl_latencylogger import latency_logger
# Third-party model imports (guarded so the server can start in partial environments)
try:
from segment_anything import sam_model_registry as sam_model_registry_sam
from segment_anything import SamPredictor as SamPredictor_sam
except Exception as e: # pragma: no cover (environmental)
sam_model_registry_sam = {}
SamPredictor_sam = None
print(f"[SAMM WARN] segment_anything import failed: {e}")

try:
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from mobile_sam import SamPredictor as SamPredictor_mobile
except Exception as e: # pragma: no cover
sam_model_registry_mobile = {}
SamPredictor_mobile = None
print(f"[SAMM WARN] mobile_sam import failed: {e}")

try:
from utl_latencylogger import latency_logger
except Exception: # pragma: no cover
latency_logger = lambda *args, **kwargs: None # noop fallback

def singleton(cls):
instances = {}
Expand All @@ -37,15 +51,44 @@ def __init__(self):
self.samPredictor = {"R": None, "G": None, "Y": None}
self.initNetwork()

def initNetwork(self, model = "vit_b"):
def _select_device(self):
"""Select the best available device with preference: XPU > CUDA > MPS > CPU."""
device = "cpu"
try:
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = "xpu"
try:
torch.xpu.set_device(0)
except Exception:
pass
print("[SAMM INFO] Intel XPU detected.")
elif torch.cuda.is_available():
device = "cuda"
print("[SAMM INFO] CUDA detected.")
elif torch.backends.mps.is_available():
device = "mps"
print("[SAMM INFO] MPS detected.")
else:
print("[SAMM INFO] Using CPU.")
except Exception as e:
print(f"[SAMM WARN] Device probe failed ({e}); using CPU.")
device = "cpu"
return device

def _safe_model_to(self, model, desc):
"""Attempt to move model to current device, fallback to CPU if it fails."""
try:
model.to(device=self.device)
except Exception as e:
if self.device != "cpu":
print(f"[SAMM WARN] Failed to move {desc} to '{self.device}' ({e}); falling back to CPU.")
self.device = "cpu"
model.to(device=self.device)
return model

# Load the segmentation model
if torch.cuda.is_available():
self.device = "cuda"
print("[SAMM INFO] CUDA detected. Waiting for Model ...")
if torch.backends.mps.is_available():
self.device = "mps"
print("[SAMM INFO] MPS detected. Waiting for Model ...")
def initNetwork(self, model = "vit_b"):
self.device = self._select_device()
print(f"[SAMM INFO] Initializing model '{model}' on {self.device} ...")

workspace = os.path.dirname(os.path.abspath(__file__))
workspace = os.path.join(workspace, 'samm-workspace')
Expand All @@ -71,7 +114,7 @@ def initNetworkSam(self, model):
raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint)
model_type = model
sam = sam_model_registry_sam[model_type](checkpoint=self.sam_checkpoint)
sam.to(device=self.device)
sam = self._safe_model_to(sam, "SAM model")

self.samPredictor["R"] = SamPredictor_sam(sam)
self.samPredictor["G"] = SamPredictor_sam(sam)
Expand All @@ -87,7 +130,7 @@ def initNetworkMobile(self, model):
raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint)
model_type = model[7:]
sam = sam_model_registry_mobile[model_type](checkpoint=self.sam_checkpoint)
sam.to(device=self.device)
sam = self._safe_model_to(sam, "Mobile SAM model")

self.samPredictor["R"] = SamPredictor_mobile(sam)
self.samPredictor["G"] = SamPredictor_mobile(sam)
Expand All @@ -103,7 +146,7 @@ def initNetworkMedSam(self, model):
raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint)
model_type = model[7:]
sam = sam_model_registry_sam[model_type](checkpoint=self.sam_checkpoint)
sam.to(device=self.device)
sam = self._safe_model_to(sam, "MedSAM model")

self.samPredictor["R"] = SamPredictor_sam(sam)
self.samPredictor["G"] = SamPredictor_sam(sam)
Expand All @@ -115,6 +158,7 @@ def sammProcessingCallBack_SET_IMAGE_SIZE(msg):
dataNode.mainVolume = np.zeros([msg["r"], msg["g"], msg["y"]], dtype = np.uint8)
dataNode.N = {"R": msg["r"], "G": msg["g"], "Y": msg["y"]}
dataNode.imageSize = [msg["r"], msg["g"], msg["y"]]
print(f"[SAMM DEBUG] Image size set to: {dataNode.imageSize}")
return np.array([1],dtype=np.uint8).tobytes(), None

def sammProcessingCallBack_SET_NTH_IMAGE(msg):
Expand Down Expand Up @@ -204,7 +248,7 @@ def testImage(dataNode, n, points, view):
img = Image.fromarray(seg[0])
img.save("testseg.png")

def helperPredict(dataNode, msg, points, labels, bbox2d):
def _helperPredict(dataNode, msg, points, labels, bbox2d):

dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to(dataNode.device)
if isinstance(bbox2d, (np.ndarray, np.generic)):
Expand Down Expand Up @@ -251,31 +295,25 @@ def sammProcessingCallBack_INFERENCE(msg):
labels.append(0)

seg = None
if len(points) > 0 and (bbox2d[0]!=-404):

if len(points) > 0 and (bbox2d[0] != -404):
points = np.array(points)
point_labels = np.array(labels)
bbox2d = np.array(bbox2d)
bbox2d[bbox2d<1] = 1
seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)

elif len(points) == 0 and (bbox2d[0]!=-404):

bbox2d[bbox2d < 1] = 1
seg = _helperPredict(dataNode, msg, points, point_labels, bbox2d)
elif len(points) == 0 and (bbox2d[0] != -404):
points = None
point_labels = None
bbox2d = np.array(bbox2d)
bbox2d[bbox2d<1] = 1
seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)

elif len(points) > 0 and (bbox2d[0]==-404):

bbox2d[bbox2d < 1] = 1
seg = _helperPredict(dataNode, msg, points, point_labels, bbox2d)
elif len(points) > 0 and (bbox2d[0] == -404):
points = np.array(points)
point_labels = np.array(labels)
bbox2d = None
seg = helperPredict(dataNode, msg, points, point_labels, bbox2d)

seg = _helperPredict(dataNode, msg, points, point_labels, bbox2d)
else:
seg = np.zeros([W, H],dtype=np.uint8)
seg = np.zeros([W, H], dtype=np.uint8)

laglog.event_complete_inference()

Expand All @@ -296,9 +334,16 @@ def sammProcessingCallBack_MODEL_SELECTION(msg):
def sammProcessingCallBack_AUTO_SEG(msg):
dataNode = SammParameterNode()
print("[SAMM INFO] Received Auto_seg command.")
print(f"[SAMM DEBUG] Current imageSize: {dataNode.imageSize}")

print(msg["segRangeMin"], msg["segRangeMax"], msg["segSlice"])

# Check if imageSize has been properly initialized
if len(dataNode.imageSize) < 3:
print("[SAMM ERROR] Image size not set. Please set image size first.")
print(f"[SAMM ERROR] imageSize length: {len(dataNode.imageSize)}, content: {dataNode.imageSize}")
return np.array([0], dtype=np.uint8).tobytes(), None

W, H = dataNode.imageSize[1], dataNode.imageSize[2]

seg = None
Expand All @@ -312,7 +357,7 @@ def sammProcessingCallBack_AUTO_SEG(msg):
bbox2d = np.array(bbox2d)
bbox2d[bbox2d<1] = 1

seg = helperPredict(
seg = _helperPredict(
dataNode,
{"view" : "R", "n" : msg["segSlice"]},
points, point_labels,
Expand Down