diff --git a/samm-python-terminal/utl_sam_server.py b/samm-python-terminal/utl_sam_server.py index 91c8084..604a7e3 100644 --- a/samm-python-terminal/utl_sam_server.py +++ b/samm-python-terminal/utl_sam_server.py @@ -1,16 +1,30 @@ from utl_sam_msg import * import numpy as np from tqdm import tqdm -import sys,os, cv2, matplotlib.pyplot as plt +import os, cv2 import torch, functools, pickle -# os.environ["CUDA_VISIBLE_DEVICES"] = "0"# -from segment_anything import sam_model_registry as sam_model_registry_sam -from segment_anything import SamPredictor as SamPredictor_sam -from mobile_sam import sam_model_registry as sam_model_registry_mobile -from mobile_sam import SamPredictor as SamPredictor_mobile - -from utl_latencylogger import latency_logger +# Third-party model imports (guarded so the server can start in partial environments) +try: + from segment_anything import sam_model_registry as sam_model_registry_sam + from segment_anything import SamPredictor as SamPredictor_sam +except Exception as e: # pragma: no cover (environmental) + sam_model_registry_sam = {} + SamPredictor_sam = None + print(f"[SAMM WARN] segment_anything import failed: {e}") + +try: + from mobile_sam import sam_model_registry as sam_model_registry_mobile + from mobile_sam import SamPredictor as SamPredictor_mobile +except Exception as e: # pragma: no cover + sam_model_registry_mobile = {} + SamPredictor_mobile = None + print(f"[SAMM WARN] mobile_sam import failed: {e}") + +try: + from utl_latencylogger import latency_logger +except Exception: # pragma: no cover + latency_logger = lambda *args, **kwargs: None # noop fallback def singleton(cls): instances = {} @@ -37,15 +51,44 @@ def __init__(self): self.samPredictor = {"R": None, "G": None, "Y": None} self.initNetwork() - def initNetwork(self, model = "vit_b"): + def _select_device(self): + """Select the best available device with preference: XPU > CUDA > MPS > CPU.""" + device = "cpu" + try: + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = "xpu" + try: + torch.xpu.set_device(0) + except Exception: + pass + print("[SAMM INFO] Intel XPU detected.") + elif torch.cuda.is_available(): + device = "cuda" + print("[SAMM INFO] CUDA detected.") + elif torch.backends.mps.is_available(): + device = "mps" + print("[SAMM INFO] MPS detected.") + else: + print("[SAMM INFO] Using CPU.") + except Exception as e: + print(f"[SAMM WARN] Device probe failed ({e}); using CPU.") + device = "cpu" + return device + + def _safe_model_to(self, model, desc): + """Attempt to move model to current device, fallback to CPU if it fails.""" + try: + model.to(device=self.device) + except Exception as e: + if self.device != "cpu": + print(f"[SAMM WARN] Failed to move {desc} to '{self.device}' ({e}); falling back to CPU.") + self.device = "cpu" + model.to(device=self.device) + return model - # Load the segmentation model - if torch.cuda.is_available(): - self.device = "cuda" - print("[SAMM INFO] CUDA detected. Waiting for Model ...") - if torch.backends.mps.is_available(): - self.device = "mps" - print("[SAMM INFO] MPS detected. Waiting for Model ...") + def initNetwork(self, model = "vit_b"): + self.device = self._select_device() + print(f"[SAMM INFO] Initializing model '{model}' on {self.device} ...") workspace = os.path.dirname(os.path.abspath(__file__)) workspace = os.path.join(workspace, 'samm-workspace') @@ -71,7 +114,7 @@ def initNetworkSam(self, model): raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint) model_type = model sam = sam_model_registry_sam[model_type](checkpoint=self.sam_checkpoint) - sam.to(device=self.device) + sam = self._safe_model_to(sam, "SAM model") self.samPredictor["R"] = SamPredictor_sam(sam) self.samPredictor["G"] = SamPredictor_sam(sam) @@ -87,7 +130,7 @@ def initNetworkMobile(self, model): raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint) model_type = model[7:] sam = sam_model_registry_mobile[model_type](checkpoint=self.sam_checkpoint) - sam.to(device=self.device) + sam = self._safe_model_to(sam, "Mobile SAM model") self.samPredictor["R"] = SamPredictor_mobile(sam) self.samPredictor["G"] = SamPredictor_mobile(sam) @@ -103,7 +146,7 @@ def initNetworkMedSam(self, model): raise Exception("[SAMM ERROR] SAM model file is not in " + self.sam_checkpoint) model_type = model[7:] sam = sam_model_registry_sam[model_type](checkpoint=self.sam_checkpoint) - sam.to(device=self.device) + sam = self._safe_model_to(sam, "MedSAM model") self.samPredictor["R"] = SamPredictor_sam(sam) self.samPredictor["G"] = SamPredictor_sam(sam) @@ -115,6 +158,7 @@ def sammProcessingCallBack_SET_IMAGE_SIZE(msg): dataNode.mainVolume = np.zeros([msg["r"], msg["g"], msg["y"]], dtype = np.uint8) dataNode.N = {"R": msg["r"], "G": msg["g"], "Y": msg["y"]} dataNode.imageSize = [msg["r"], msg["g"], msg["y"]] + print(f"[SAMM DEBUG] Image size set to: {dataNode.imageSize}") return np.array([1],dtype=np.uint8).tobytes(), None def sammProcessingCallBack_SET_NTH_IMAGE(msg): @@ -204,7 +248,7 @@ def testImage(dataNode, n, points, view): img = Image.fromarray(seg[0]) img.save("testseg.png") -def helperPredict(dataNode, msg, points, labels, bbox2d): +def _helperPredict(dataNode, msg, points, labels, bbox2d): dataNode.samPredictor[msg["view"]].features = dataNode.features[msg["view"]][msg["n"]].to(dataNode.device) if isinstance(bbox2d, (np.ndarray, np.generic)): @@ -251,31 +295,25 @@ def sammProcessingCallBack_INFERENCE(msg): labels.append(0) seg = None - if len(points) > 0 and (bbox2d[0]!=-404): - + if len(points) > 0 and (bbox2d[0] != -404): points = np.array(points) point_labels = np.array(labels) bbox2d = np.array(bbox2d) - bbox2d[bbox2d<1] = 1 - seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) - - elif len(points) == 0 and (bbox2d[0]!=-404): - + bbox2d[bbox2d < 1] = 1 + seg = _helperPredict(dataNode, msg, points, point_labels, bbox2d) + elif len(points) == 0 and (bbox2d[0] != -404): points = None point_labels = None bbox2d = np.array(bbox2d) - bbox2d[bbox2d<1] = 1 - seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) - - elif len(points) > 0 and (bbox2d[0]==-404): - + bbox2d[bbox2d < 1] = 1 + seg = _helperPredict(dataNode, msg, points, point_labels, bbox2d) + elif len(points) > 0 and (bbox2d[0] == -404): points = np.array(points) point_labels = np.array(labels) bbox2d = None - seg = helperPredict(dataNode, msg, points, point_labels, bbox2d) - + seg = _helperPredict(dataNode, msg, points, point_labels, bbox2d) else: - seg = np.zeros([W, H],dtype=np.uint8) + seg = np.zeros([W, H], dtype=np.uint8) laglog.event_complete_inference() @@ -296,9 +334,16 @@ def sammProcessingCallBack_MODEL_SELECTION(msg): def sammProcessingCallBack_AUTO_SEG(msg): dataNode = SammParameterNode() print("[SAMM INFO] Received Auto_seg command.") + print(f"[SAMM DEBUG] Current imageSize: {dataNode.imageSize}") print(msg["segRangeMin"], msg["segRangeMax"], msg["segSlice"]) + # Check if imageSize has been properly initialized + if len(dataNode.imageSize) < 3: + print("[SAMM ERROR] Image size not set. Please set image size first.") + print(f"[SAMM ERROR] imageSize length: {len(dataNode.imageSize)}, content: {dataNode.imageSize}") + return np.array([0], dtype=np.uint8).tobytes(), None + W, H = dataNode.imageSize[1], dataNode.imageSize[2] seg = None @@ -312,7 +357,7 @@ def sammProcessingCallBack_AUTO_SEG(msg): bbox2d = np.array(bbox2d) bbox2d[bbox2d<1] = 1 - seg = helperPredict( + seg = _helperPredict( dataNode, {"view" : "R", "n" : msg["segSlice"]}, points, point_labels,