From d0eb32665f32cb084106b69dfa068be3beb464b9 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:35:10 +0000 Subject: [PATCH 1/5] feat(sv_split): track max sv id to create new ids; convert ws seg to ocdbt --- pychunkedgraph/graph/client/base.py | 6 +- .../graph/client/bigtable/client.py | 11 ++++ pychunkedgraph/graph/ocdbt.py | 63 +++++++++++++++++++ pychunkedgraph/ingest/cli.py | 2 + pychunkedgraph/ingest/cluster.py | 4 ++ pychunkedgraph/ingest/create/atomic_layer.py | 5 +- pychunkedgraph/tests/test_uncategorized.py | 18 ++++-- 7 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 pychunkedgraph/graph/ocdbt.py diff --git a/pychunkedgraph/graph/client/base.py b/pychunkedgraph/graph/client/base.py index 953734670..5dd473bd2 100644 --- a/pychunkedgraph/graph/client/base.py +++ b/pychunkedgraph/graph/client/base.py @@ -115,13 +115,17 @@ class ClientWithIDGen(SimpleClient): """ @abstractmethod - def create_node_ids(self, chunk_id): + def create_node_ids(self, chunk_id, size: int): """Generate a range of unique IDs in the chunk.""" @abstractmethod def create_node_id(self, chunk_id): """Generate a unique ID in the chunk.""" + @abstractmethod + def set_max_node_id(self, chunk_id, node_id): + """Gets the current maximum node ID in the chunk.""" + @abstractmethod def get_max_node_id(self, chunk_id): """Gets the current maximum node ID in the chunk.""" diff --git a/pychunkedgraph/graph/client/bigtable/client.py b/pychunkedgraph/graph/client/bigtable/client.py index 9195fb397..56e5476c7 100644 --- a/pychunkedgraph/graph/client/bigtable/client.py +++ b/pychunkedgraph/graph/client/bigtable/client.py @@ -615,6 +615,17 @@ def create_node_id( """Generate a unique node ID in the chunk.""" return self.create_node_ids(chunk_id, 1, root_chunk=root_chunk)[0] + def set_max_node_id( + self, chunk_id: np.uint64, node_id: np.uint64 + ) -> basetypes.NODE_ID: + """Set max segment ID for a given chunk.""" + size = int(np.uint64(chunk_id) ^ np.uint64(node_id)) + key = serialize_uint64(chunk_id, counter=True) + column = attributes.Concurrency.Counter + row = self._table.append_row(key) + row.increment_cell_value(column.family_id, column.key, size) + row = row.commit() + def get_max_node_id( self, chunk_id: basetypes.CHUNK_ID, root_chunk=False ) -> basetypes.NODE_ID: diff --git a/pychunkedgraph/graph/ocdbt.py b/pychunkedgraph/graph/ocdbt.py new file mode 100644 index 000000000..03c6d9b65 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt.py @@ -0,0 +1,63 @@ +import os +import numpy as np +import tensorstore as ts + +OCDBT_SEG_COMPRESSION_LEVEL = 17 + + +def get_seg_source_and_destination_ocdbt(ws_path: str, create: bool = False) -> tuple: + src_spec = { + "driver": "neuroglancer_precomputed", + "kvstore": ws_path, + } + src = ts.open(src_spec).result() + schema = src.schema + + ocdbt_path = os.path.join(ws_path, "ocdbt", "base") + dst_spec = { + "driver": "neuroglancer_precomputed", + "kvstore": { + "driver": "ocdbt", + "base": ocdbt_path, + "config": { + "compression": {"id": "zstd", "level": OCDBT_SEG_COMPRESSION_LEVEL}, + }, + }, + } + + dst = ts.open( + dst_spec, + create=create, + rank=schema.rank, + dtype=schema.dtype, + codec=schema.codec, + domain=schema.domain, + shape=schema.shape, + chunk_layout=schema.chunk_layout, + dimension_units=schema.dimension_units, + delete_existing=create, + ).result() + return (src, dst) + + +def copy_ws_chunk( + source, + destination, + chunk_size: tuple, + coords: list, + voxel_bounds: np.ndarray, +): + coords = np.array(coords, dtype=int) + 
chunk_size = np.array(chunk_size, dtype=int) + vx_start = coords * chunk_size + voxel_bounds[:, 0] + vx_end = vx_start + chunk_size + xE, yE, zE = voxel_bounds[:, 1] + + x0, y0, z0 = vx_start + x1, y1, z1 = vx_end + x1 = min(x1, xE) + y1 = min(y1, yE) + z1 = min(z1, zE) + + data = source[x0:x1, y0:y1, z0:z1].read().result() + destination[x0:x1, y0:y1, z0:z1].write(data).result() diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index c50525ec6..8d44bf276 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -23,6 +23,7 @@ from .simple_tests import run_all from .create.parent_layer import add_parent_chunk from ..graph.chunkedgraph import ChunkedGraph +from ..graph.ocdbt import get_seg_source_and_destination_ocdbt from ..utils.redis import get_redis_connection, keys as r_keys group_name = "ingest" @@ -71,6 +72,7 @@ def ingest_graph( imanager = IngestionManager(ingest_config, meta) enqueue_l2_tasks(imanager, create_atomic_chunk) + get_seg_source_and_destination_ocdbt(cg.meta, create=True) @ingest_cli.command("imanager") diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 1ae13a353..33c2db2c5 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -26,6 +26,7 @@ from .upgrade.parent_layer import update_chunk as update_parent_chunk from ..graph.edges import EDGE_TYPES, Edges, put_edges from ..graph import ChunkedGraph, ChunkedGraphMeta +from ..graph.ocdbt import copy_ws_chunk, get_seg_source_and_destination_ocdbt from ..graph.chunks.hierarchy import get_children_chunk_coords from ..graph.utils.basetypes import NODE_ID from ..io.edges import get_chunk_edges @@ -127,6 +128,9 @@ def create_atomic_chunk(coords: Sequence[int]): logging.debug(f"{k}: {len(v)}") for k, v in chunk_edges_active.items(): logging.debug(f"active_{k}: {len(v)}") + + src, dst = get_seg_source_and_destination_ocdbt(imanager.cg.meta) + copy_ws_chunk(imanager.cg, coords, src, dst) _post_task_completion(imanager, 2, coords) diff --git a/pychunkedgraph/ingest/create/atomic_layer.py b/pychunkedgraph/ingest/create/atomic_layer.py index 0a7aae728..25673855e 100644 --- a/pychunkedgraph/ingest/create/atomic_layer.py +++ b/pychunkedgraph/ingest/create/atomic_layer.py @@ -35,7 +35,10 @@ def add_atomic_chunk( return chunk_ids = cg.get_chunk_ids_from_node_ids(chunk_node_ids) - assert len(np.unique(chunk_ids)) == 1 + assert len(np.unique(chunk_ids)) == 1, np.unique(chunk_ids) + + max_node_id = np.max(chunk_node_ids) + cg.id_client.set_max_node_id(chunk_ids[0], max_node_id) graph, _, _, unique_ids = build_gt_graph(chunk_edge_ids, make_directed=True) ccs = connected_components(graph) diff --git a/pychunkedgraph/tests/test_uncategorized.py b/pychunkedgraph/tests/test_uncategorized.py index 5c2de29d4..766b81bca 100644 --- a/pychunkedgraph/tests/test_uncategorized.py +++ b/pychunkedgraph/tests/test_uncategorized.py @@ -107,6 +107,8 @@ def test_build_single_node(self, gen_graph): cg = gen_graph(n_layers=2) # Add Chunk A create_chunk(cg, vertices=[to_label(cg, 1, 0, 0, 0, 0)]) + chunk_id = to_label(cg, 1, 0, 0, 0, 0) + assert cg.id_client.get_max_node_id(chunk_id) == chunk_id res = cg.client._table.read_rows() res.consume_all() @@ -130,7 +132,7 @@ def test_build_single_node(self, gen_graph): assert len(children) == 1 and children[0] == to_label(cg, 1, 0, 0, 0, 0) # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 1 + 1 + 1 + 1 + 1 + assert len(res.rows) == 1 + 1 
+ 1 + 1 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_edge(self, gen_graph): @@ -151,6 +153,8 @@ def test_build_single_edge(self, gen_graph): vertices=[to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1)], edges=[(to_label(cg, 1, 0, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 1), 0.5)], ) + chunk_id = to_label(cg, 1, 0, 0, 0, 0) + assert cg.id_client.get_max_node_id(chunk_id) == to_label(cg, 1, 0, 0, 0, 1) res = cg.client._table.read_rows() res.consume_all() @@ -183,7 +187,7 @@ def test_build_single_edge(self, gen_graph): # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 2 + 1 + 1 + 1 + 1 + assert len(res.rows) == 2 + 1 + 1 + 1 + 1 + 1 @pytest.mark.timeout(30) def test_build_single_across_edge(self, gen_graph): @@ -285,7 +289,7 @@ def test_build_single_across_edge(self, gen_graph): # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + 1 + assert len(res.rows) == 2 + 2 + 1 + 3 + 1 + 1 + 2 @pytest.mark.timeout(30) def test_build_single_edge_and_single_across_edge(self, gen_graph): @@ -311,6 +315,9 @@ def test_build_single_edge_and_single_across_edge(self, gen_graph): ], ) + chunk_id = to_label(cg, 1, 0, 0, 0, 0) + assert cg.id_client.get_max_node_id(chunk_id) == to_label(cg, 1, 0, 0, 0, 1) + # Chunk B create_chunk( cg, @@ -318,6 +325,9 @@ def test_build_single_edge_and_single_across_edge(self, gen_graph): edges=[(to_label(cg, 1, 1, 0, 0, 0), to_label(cg, 1, 0, 0, 0, 0), inf)], ) + chunk_id = to_label(cg, 1, 1, 0, 0, 0) + assert cg.id_client.get_max_node_id(chunk_id) == to_label(cg, 1, 1, 0, 0, 0) + add_parent_chunk(cg, 3, np.array([0, 0, 0]), n_threads=1) res = cg.client._table.read_rows() res.consume_all() @@ -393,7 +403,7 @@ def test_build_single_edge_and_single_across_edge(self, gen_graph): # Make sure there are not any more entries in the table # include counters, meta and version rows - assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + 1 + assert len(res.rows) == 3 + 2 + 1 + 3 + 1 + 1 + 2 @pytest.mark.timeout(120) def test_build_big_graph(self, gen_graph): From 9aacb4361881efde7e2a5b4d488b2ce992f34573 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:36:20 +0000 Subject: [PATCH 2/5] feat(sv_split): metadata changes to support ocdbt seg --- pychunkedgraph/graph/meta.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/pychunkedgraph/graph/meta.py b/pychunkedgraph/graph/meta.py index 83d670ffe..6a938f802 100644 --- a/pychunkedgraph/graph/meta.py +++ b/pychunkedgraph/graph/meta.py @@ -2,17 +2,16 @@ from datetime import timedelta from typing import Dict from typing import List -from typing import Tuple from typing import Sequence from collections import namedtuple import numpy as np from cloudvolume import CloudVolume +from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt + from .utils.generic import compute_bitmasks from .chunks.utils import get_chunks_boundary -from ..utils.redis import keys as r_keys -from ..utils.redis import get_rq_queue from ..utils.redis import get_redis_connection @@ -64,9 +63,11 @@ def __init__( self._custom_data = custom_data self._ws_cv = None + self._ws_ocdbt = None self._layer_bounds_d = None self._layer_count = None self._bitmasks = None + self._ocdbt_seg = None @property def graph_config(self): @@ -91,15 +92,33 @@ def ws_cv(self): # useful to avoid md5 errors on high gcs load redis = 
get_redis_connection() cached_info = json.loads(redis.get(cache_key)) - self._ws_cv = CloudVolume(self._data_source.WATERSHED, info=cached_info) + self._ws_cv = CloudVolume( + self._data_source.WATERSHED, info=cached_info, progress=False + ) except Exception: - self._ws_cv = CloudVolume(self._data_source.WATERSHED) + self._ws_cv = CloudVolume(self._data_source.WATERSHED, progress=False) try: redis.set(cache_key, json.dumps(self._ws_cv.info)) except Exception: ... return self._ws_cv + @property + def ocdbt_seg(self) -> bool: + if self._ocdbt_seg is None: + self._ocdbt_seg = self._custom_data.get("seg", {}).get("ocdbt", False) + return self._ocdbt_seg + + @property + def ws_ocdbt(self): + assert self.ocdbt_seg, "make sure this pcg has segmentation in ocdbt format" + if self._ws_ocdbt: + return self._ws_ocdbt + + _, _ocdbt_seg = get_seg_source_and_destination_ocdbt(self.data_source.WATERSHED) + self._ws_ocdbt = _ocdbt_seg + return self._ws_ocdbt + @property def resolution(self): return self.ws_cv.resolution # pylint: disable=no-member @@ -235,11 +254,14 @@ def split_bounding_offset(self): @property def dataset_info(self) -> Dict: info = self.ws_cv.info # pylint: disable=no-member - info.update( { "chunks_start_at_voxel_offset": True, - "data_dir": self.data_source.WATERSHED, + "data_dir": ( + self.ws_ocdbt.kvstore.base.url + if self.ocdbt_seg + else self.data_source.WATERSHED + ), "graph": { "chunk_size": self.graph_config.CHUNK_SIZE, "bounding_box": [2048, 2048, 512], From 375f6203313edf53da95a93169ec62bda60f1dc6 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:39:58 +0000 Subject: [PATCH 3/5] feat(sv_split): split sv, update seg and edges, read and write new edges from pcg --- pychunkedgraph/graph/attributes.py | 19 + pychunkedgraph/graph/chunkedgraph.py | 42 +- pychunkedgraph/graph/chunks/utils.py | 54 +- pychunkedgraph/graph/cutting_sv.py | 1284 ++++++++++++++++++++++ pychunkedgraph/graph/edits_sv.py | 439 ++++++++ pychunkedgraph/graph/types.py | 3 +- pychunkedgraph/graph/utils/__init__.py | 1 + pychunkedgraph/graph/utils/generic.py | 18 +- pychunkedgraph/graph/utils/id_helpers.py | 6 +- pychunkedgraph/meshing/meshgen_utils.py | 22 +- requirements.in | 3 + requirements.txt | 31 +- 12 files changed, 1878 insertions(+), 44 deletions(-) create mode 100644 pychunkedgraph/graph/cutting_sv.py create mode 100644 pychunkedgraph/graph/edits_sv.py diff --git a/pychunkedgraph/graph/attributes.py b/pychunkedgraph/graph/attributes.py index b431a159b..43d6777aa 100644 --- a/pychunkedgraph/graph/attributes.py +++ b/pychunkedgraph/graph/attributes.py @@ -132,6 +132,13 @@ class Connectivity: serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), ) + # new edges as a result of supervoxel split + SplitEdges = _Attribute( + key=b"split_edges", + family_id="4", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)), + ) + class Hierarchy: Child = _Attribute( @@ -160,6 +167,18 @@ class Hierarchy: serializer=serializers.NumPyValue(dtype=basetypes.NODE_ID), ) + FormerIdentity = _Attribute( + key=b"former_ids", + family_id="0", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + + NewIdentity = _Attribute( + key=b"new_ids", + family_id="0", + serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID), + ) + class GraphMeta: key = b"meta" diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 1754315d8..7631d31c8 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ 
b/pychunkedgraph/graph/chunkedgraph.py @@ -626,22 +626,44 @@ def get_subgraph_leaves( self, node_id_or_ids, bbox, bbox_is_coordinate, False, True ) - def get_fake_edges( + def get_edited_edges( self, chunk_ids: np.ndarray, time_stamp: datetime.datetime = None ) -> typing.Dict: + """ + Edges stored within a pcg that were created as a result of edits. + Either 'fake' edges that were adding for a merge edit; + Or 'split' edges resulting from a supervoxel split. + """ result = {} - fake_edges_d = self.client.read_nodes( + properties = [ + attributes.Connectivity.FakeEdges, + attributes.Connectivity.SplitEdges, + attributes.Connectivity.Affinity, + attributes.Connectivity.Area, + ] + _edges_d = self.client.read_nodes( node_ids=chunk_ids, - properties=attributes.Connectivity.FakeEdges, + properties=properties, end_time=time_stamp, end_time_inclusive=True, fake_edges=True, ) - for id_, val in fake_edges_d.items(): - edges = np.concatenate( - [np.array(e.value, dtype=basetypes.NODE_ID, copy=False) for e in val] - ) - result[id_] = Edges(edges[:, 0], edges[:, 1]) + for id_, val in _edges_d.items(): + edges = val.get(attributes.Connectivity.FakeEdges, []) + edges = np.concatenate([types.empty_2d, *[e.value for e in edges]]) + fake_edges_ = Edges(edges[:, 0], edges[:, 1]) + + edges = val.get(attributes.Connectivity.SplitEdges, []) + edges = np.concatenate([types.empty_2d, *[e.value for e in edges]]) + + aff = val.get(attributes.Connectivity.Affinity, []) + aff = np.concatenate([types.empty_affinities, *[e.value for e in aff]]) + + areas = val.get(attributes.Connectivity.Area, []) + areas = np.concatenate([types.empty_areas, *[e.value for e in areas]]) + split_edges_ = Edges(edges[:, 0], edges[:, 1], affinities=aff, areas=areas) + + result[id_] = fake_edges_ + split_edges_ return result def copy_fake_edges(self, chunk_id: np.uint64) -> None: @@ -680,10 +702,10 @@ def get_l2_agglomerations( if self.mock_edges is None: edges_d = self.read_chunk_edges(chunk_ids) - fake_edges = self.get_fake_edges(chunk_ids) + edited_edges = self.get_edited_edges(chunk_ids) all_chunk_edges = reduce( lambda x, y: x + y, - chain(edges_d.values(), fake_edges.values()), + chain(edges_d.values(), edited_edges.values()), Edges([], []), ) if self.mock_edges is not None: diff --git a/pychunkedgraph/graph/chunks/utils.py b/pychunkedgraph/graph/chunks/utils.py index f22a4d84a..12ec54929 100644 --- a/pychunkedgraph/graph/chunks/utils.py +++ b/pychunkedgraph/graph/chunks/utils.py @@ -167,9 +167,7 @@ def _compute_chunk_id( z: int, ) -> np.uint64: s_bits_per_dim = meta.bitmasks[layer] - if not ( - x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim - ): + if not (x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim): raise ValueError( f"Coordinate is out of range \ layer: {layer} bits/dim {s_bits_per_dim}. \ @@ -238,3 +236,53 @@ def get_bounding_children_chunks( if return_unique: return np.unique(result, axis=0) if result.size else result return result + + +def chunks_overlapping_bbox(bbox_min, bbox_max, chunk_size) -> dict: + """ + Find octree chunks overlapping with a bounding box in 3D + and return a dictionary mapping chunk indices to clipped bounding boxes. 
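
    The index arithmetic this helper relies on reduces to floor-dividing both bounding-box
    corners by the chunk size and clipping each overlapping chunk back to the box. A minimal
    standalone sketch with made-up numbers (not part of the patch; the real grid comes from
    the graph metadata):

        import numpy as np

        bbox_min = np.array([100, 100, 10])
        bbox_max = np.array([300, 260, 40])
        chunk_size = np.array([128, 128, 32])

        start_idx = bbox_min // chunk_size   # -> [0, 0, 0]
        end_idx = bbox_max // chunk_size     # -> [2, 2, 1]
        # every index in the inclusive range overlaps the box: 3 * 3 * 2 = 18 chunks
        for x in range(start_idx[0], end_idx[0] + 1):
            for y in range(start_idx[1], end_idx[1] + 1):
                for z in range(start_idx[2], end_idx[2] + 1):
                    chunk_min = np.array([x, y, z]) * chunk_size
                    # clip the chunk extent to the requested bounding box
                    clipped = np.stack(
                        [np.maximum(chunk_min, bbox_min),
                         np.minimum(chunk_min + chunk_size, bbox_max)]
                    )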
+ """ + bbox_min = np.asarray(bbox_min, dtype=int) + bbox_max = np.asarray(bbox_max, dtype=int) + chunk_size = np.asarray(chunk_size, dtype=int) + + start_idx = np.floor_divide(bbox_min, chunk_size).astype(int) + end_idx = np.floor_divide(bbox_max, chunk_size).astype(int) + + ix = np.arange(start_idx[0], end_idx[0] + 1) + iy = np.arange(start_idx[1], end_idx[1] + 1) + iz = np.arange(start_idx[2], end_idx[2] + 1) + grid = np.stack(np.meshgrid(ix, iy, iz, indexing="ij"), axis=-1, dtype=int) + grid = grid.reshape(-1, 3) + + chunk_min = grid * chunk_size + chunk_max = chunk_min + chunk_size + clipped_min = np.maximum(chunk_min, bbox_min) + clipped_max = np.minimum(chunk_max, bbox_max) + return { + tuple(idx): np.stack([cmin, cmax], axis=0, dtype=int) + for idx, cmin, cmax in zip(grid, clipped_min, clipped_max) + } + + +def get_neighbors(coord, inclusive: bool = True, min_coord=None, max_coord=None): + """ + Get all valid coordinates in the 3×3×3 cube around a given chunk, + including the chunk itself (if inclusive=True), + respecting bounding box constraints. + """ + offsets = np.array(np.meshgrid([-1, 0, 1], [-1, 0, 1], [-1, 0, 1])).T.reshape(-1, 3) + if not inclusive: + offsets = offsets[~np.all(offsets == 0, axis=1)] + + neighbors = np.array(coord) + offsets + if min_coord is None: + min_coord = (0, 0, 0) + min_coord = np.array(min_coord) + neighbors = neighbors[(neighbors >= min_coord).all(axis=1)] + + if max_coord is not None: + max_coord = np.array(max_coord) + neighbors = neighbors[(neighbors <= max_coord).all(axis=1)] + return neighbors diff --git a/pychunkedgraph/graph/cutting_sv.py b/pychunkedgraph/graph/cutting_sv.py new file mode 100644 index 000000000..5f9ba58c5 --- /dev/null +++ b/pychunkedgraph/graph/cutting_sv.py @@ -0,0 +1,1284 @@ +from time import perf_counter + +import numpy as np +from typing import Dict, Tuple, Optional, Sequence +from scipy.spatial import cKDTree + + +# EDT backends: prefer Seung-Lab edt, fallback to scipy.ndimage +try: + from edt import edt as _edt_fast + + _HAVE_EDT_FAST = True +except Exception: + _HAVE_EDT_FAST = False + +from scipy import ndimage as ndi +from scipy.spatial import cKDTree +from skimage.graph import MCP_Geometric +from skimage.morphology import ( + ball, +) # keep only ball; use ndi.binary_dilation everywhere + +# ---------- Fast CC wrappers ---------- +try: + import cc3d + + _HAVE_CC3D = True +except Exception: + _HAVE_CC3D = False + from skimage.measure import label as _sk_label + +try: + import fastremap as _fr + + _HAVE_FASTREMAP = True +except Exception: + _HAVE_FASTREMAP = False + + +def _cc_label_26(mask: np.ndarray): + """ + Fast 3D connected components (26-connectivity). + Returns (labels:int32, n_components:int). + """ + if _HAVE_CC3D: + lbl = cc3d.connected_components( + mask.astype(np.uint8, copy=False), connectivity=26, out_dtype=np.uint32 + ) + return lbl, int(lbl.max()) + # Fallback: skimage (connectivity=3 ~ 26-neighborhood) + lbl = _sk_label(mask, connectivity=3).astype(np.int32, copy=False) + return lbl, int(lbl.max()) + + +def _largest_component_id(lbl: np.ndarray): + """ + Return the label ID (>=1) of the largest component in 'lbl'. + lbl should already be a CC label image where 0=background. 
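
    When the cc3d/fastremap fast path is unavailable, the fallback amounts to labelling with
    26-connectivity and taking the most frequent non-background label. A toy sketch of that
    fallback, assuming only scikit-image and NumPy:

        import numpy as np
        from skimage.measure import label

        mask = np.zeros((1, 5, 5), dtype=bool)
        mask[0, :2, :2] = True              # 4-voxel component
        mask[0, 4, 4] = True                # isolated 1-voxel component

        lbl = label(mask, connectivity=3)   # connectivity=3 ~ 26-neighbourhood in 3D
        counts = np.bincount(lbl.ravel())
        counts[0] = 0                       # ignore background
        largest_id = int(np.argmax(counts)) # -> id of the 4-voxel component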
+ """ + if _HAVE_FASTREMAP: + u, counts = _fr.unique(lbl, return_counts=True) + if u.size: + bg = np.where(u == 0)[0] + if bg.size: + counts[bg[0]] = 0 + return int(u[np.argmax(counts)]) + return 0 + cnt = np.bincount(lbl.ravel()) + if cnt.size: + cnt[0] = 0 + return int(np.argmax(cnt)) if cnt.size else 0 + + +# ========================= +# Order / utility helpers +# ========================= +def _to_zyx_sampling(vs, vox_order): + vs = tuple(map(float, vs)) + if vox_order.lower() == "xyz": # (x,y,z) -> (z,y,x) + return (vs[2], vs[1], vs[0]) + if vox_order.lower() == "zyx": + return vs + raise ValueError("vox_order must be 'xyz' or 'zyx'") + + +def _to_internal_zyx_volume(vol, vol_order): + if vol_order.lower() == "zyx": + return vol, False + if vol_order.lower() == "xyz": # (x,y,z) -> (z,y,x) + return np.transpose(vol, (2, 1, 0)), True + raise ValueError("vol_order must be 'xyz' or 'zyx'") + + +def _from_internal_zyx_volume(vol_zyx, vol_order): + if vol_order.lower() == "zyx": + return vol_zyx + if vol_order.lower() == "xyz": # (z,y,x) -> (x,y,z) + return np.transpose(vol_zyx, (2, 1, 0)) + raise ValueError("vol_order must be 'xyz' or 'zyx'") + + +def _seeds_to_zyx(seeds, seed_order): + arr = np.asarray(seeds, dtype=float).reshape(-1, 3) + if seed_order.lower() == "xyz": + arr = arr[:, [2, 1, 0]] # (x,y,z) -> (z,y,x) + elif seed_order.lower() != "zyx": + raise ValueError("seed_order must be 'xyz' or 'zyx'") + return np.round(arr).astype(int) + + +def _seeds_from_zyx(seeds_zyx, seed_order): + arr = np.asarray(seeds_zyx, dtype=int).reshape(-1, 3) + if seed_order.lower() == "xyz": + return arr[:, [2, 1, 0]] # (z,y,x) -> (x,y,z) + elif seed_order.lower() == "zyx": + return arr + else: + raise ValueError("seed_order must be 'xyz' or 'zyx'") + + +# ========================= +# Snapping (KDTree-based) +# ========================= +def _extract_mask_boundary(mask, erosion_iters=1): + """ + Extract boundary voxels of a 3D mask using binary erosion. + Boundary = mask & (~eroded(mask)) + + Parameters: + mask : 3D boolean array + erosion_iters : number of erosion iterations (higher removes thicker border) + + Returns: + boundary_mask : 3D boolean array of the same shape + """ + if erosion_iters < 1: + # No erosion => boundary = mask (not recommended unless extremely thin structures) + return mask.copy() + + structure = np.ones((3, 3, 3), dtype=bool) + interior = ndi.binary_erosion( + mask, structure=structure, iterations=erosion_iters, border_value=0 + ) + boundary = mask & (~interior) + return boundary + + +def _downsample_points(points, mode="stride", stride=2, target=None, rng=None): + """ + Downsample a set of points (N,3) by either: + - 'stride': take one every 'stride' points (fast, deterministic), + - 'random': keep ~target points uniformly at random. 
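
    The two modes differ only in how candidates are thinned; a tiny illustrative sketch
    (toy coordinates, not part of the patch):

        import numpy as np

        rng = np.random.default_rng(0)
        points = np.argwhere(np.ones((10, 10, 10), dtype=bool))   # 1000 toy coordinates

        every_4th = points[::4]                                    # 'stride' mode, deterministic
        idx = rng.choice(len(points), size=250, replace=False)     # 'random' mode, ~target points
        random_250 = points[idx]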
+ + Args: + points : (N, 3) int or float array of coordinates + mode : 'stride' or 'random' + stride : int >= 1 (for 'stride' mode) + target : number of points to keep (for 'random' mode); if None, default is 50k + rng : np.random.Generator for reproducible random sampling + + Returns: + (M, 3) array with M <= N + """ + n = points.shape[0] + if n == 0: + return points + + if mode == "stride": + stride = max(1, int(stride)) + return points[::stride] + + elif mode == "random": + if target is None: + target = min(n, 50_000) # default target + target = max(1, int(target)) + if target >= n: + return points + if rng is None: + rng = np.random.default_rng() + idx = rng.choice(n, size=target, replace=False) + return points[idx] + + else: + raise ValueError("downsample mode must be 'stride' or 'random'") + + +def snap_seeds_to_segment( + seeds_xyz, + mask, + mask_order="zyx", + voxel_size=(1.0, 1.0, 1.0), + use_boundary=True, + erosion_iters=1, + downsample=True, + downsample_mode="stride", # 'stride' or 'random' + downsample_stride=2, # used if mode='stride' + downsample_target=None, # used if mode='random' + rng=None, + return_index=False, + leafsize=16, + log=lambda x: None, + tag="snap", + method="kdtree", # accepted for compatibility; only 'kdtree' currently +): + """ + Snap seeds (in XYZ) to the closest True voxel of a 3D mask using cKDTree over + a *reduced* set of candidate voxels: + - boundary-only (mask & ~eroded(mask)), if use_boundary=True + - optionally downsampled (stride or random) + + This approach works well for speed while retaining high accuracy for snapping. + + Parameters: + seeds_xyz : (N,3) float or int array in XYZ order. + mask : 3D boolean array; binary segment. + mask_order : 'zyx' (default) or 'xyz' indicating memory layout of mask. + voxel_size : (vx, vy, vz) in XYZ physical units (e.g., (8.0, 8.0, 40.0)). + use_boundary : If True, only use boundary voxels for KDTree. + erosion_iters : Number of erosion iterations for boundary extraction. + downsample : If True, further reduce boundary points (stride or random). + downsample_mode : 'stride' or 'random' for boundary sampling. + downsample_stride : If stride mode, use every Nth boundary voxel. + downsample_target : If random mode, target number of boundary points to keep. + rng : Optional np.random.Generator for reproducible random sampling. + return_index : If True, also return indices of nearest boundary points. + leafsize : cKDTree leafsize parameter. + log : callable for logging + tag : string to prefix timings + method : currently only 'kdtree' supported. Present for backward compatibility. + + Returns: + snapped_xyz : (N,3) int array in XYZ order, coordinates within volume bounds. + match_idx : (optional) indices into the candidate points array, if return_index=True. + + Notes: + - Seeds outside the volume are supported; they will snap to the nearest segment voxel. + - If use_boundary=True yields no boundary (thin segment), we fall back to the full mask. + - If the mask is empty, we raise ValueError. 
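
    The core of the snapping step is a single nearest-neighbour query in physical
    (anisotropy-scaled) space. A minimal sketch of that idea with a toy mask and one
    out-of-volume seed, assuming only SciPy's cKDTree:

        import numpy as np
        from scipy.spatial import cKDTree

        mask = np.zeros((4, 8, 8), dtype=bool)     # (Z, Y, X)
        mask[1:3, 2:6, 2:6] = True
        voxel_size = (8.0, 8.0, 40.0)              # (x, y, z) physical units

        zc, yc, xc = np.where(mask)
        points_xyz = np.stack([xc, yc, zc], axis=1)
        scale = np.array(voxel_size)

        seed_xyz = np.array([[0.0, 0.0, 0.0]])     # seed outside the mask
        tree = cKDTree(points_xyz * scale)         # build tree in physical space
        _, idx = tree.query(seed_xyz * scale, k=1)
        snapped = points_xyz[idx]                  # nearest in-mask voxel, XYZ order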
+ """ + t0 = perf_counter() + if method != "kdtree": + log(f"[{tag}] Warning: 'method={method}' not supported; using 'kdtree'.") + + # Validate mask + if mask.ndim != 3: + raise ValueError("mask must be a 3D boolean array") + if mask.dtype != bool: + mask = mask.astype(bool) + + if mask_order not in ("zyx", "xyz"): + raise ValueError("mask_order must be 'zyx' or 'xyz'") + + # Optional boundary extraction for speed + tb = perf_counter() + if use_boundary: + candidate_mask = _extract_mask_boundary(mask, erosion_iters=erosion_iters) + # Fallback to full mask if boundary is empty + if not candidate_mask.any(): + candidate_mask = mask + log(f"[{tag}] boundary empty → fallback to full mask") + else: + candidate_mask = mask + log(f"[{tag}] candidate extraction | {perf_counter()-tb:.3f}s") + + # Obtain candidate voxel coordinates in XYZ order + tc = perf_counter() + if mask_order == "zyx": + # mask shape is (Z, Y, X), np.where -> (z, y, x) + zc, yc, xc = np.where(candidate_mask) + points_xyz = np.stack([xc, yc, zc], axis=1) + max_x, max_y, max_z = mask.shape[2] - 1, mask.shape[1] - 1, mask.shape[0] - 1 + else: + # mask shape is (X, Y, Z), np.where -> (x, y, z) + xc, yc, zc = np.where(candidate_mask) + points_xyz = np.stack([xc, yc, zc], axis=1) + max_x, max_y, max_z = mask.shape[0] - 1, mask.shape[1] - 1, mask.shape[2] - 1 + log( + f"[{tag}] candidate coordinates | {perf_counter()-tc:.3f}s (n={len(points_xyz)})" + ) + + if points_xyz.shape[0] == 0: + raise ValueError( + "The mask (or boundary) contains no True voxels (empty segment)." + ) + + # Optional: further downsample candidate points + td = perf_counter() + if downsample: + before = len(points_xyz) + points_xyz = _downsample_points( + points_xyz, + mode=downsample_mode, + stride=downsample_stride, + target=downsample_target, + rng=rng, + ) + after = len(points_xyz) + log(f"[{tag}] downsample points {before} → {after} | {perf_counter()-td:.3f}s") + + # Prepare seeds array + seeds_xyz = np.asarray(seeds_xyz, dtype=np.float64) + if seeds_xyz.ndim == 1: + seeds_xyz = seeds_xyz[None, :] + if seeds_xyz.shape[1] != 3: + raise ValueError("seeds_xyz must be shape (N, 3)") + + # Scale coordinates to physical space to respect anisotropy + vx, vy, vz = voxel_size + scale = np.array([vx, vy, vz], dtype=np.float64) + + points_scaled = points_xyz * scale[None, :] + seeds_scaled = seeds_xyz * scale[None, :] + + # cKDTree nearest neighbor lookup + te = perf_counter() + tree = cKDTree(points_scaled, leafsize=leafsize) + _, nn_indices = tree.query(seeds_scaled, k=1, workers=-1) + log(f"[{tag}] KDTree build+query | {perf_counter()-te:.3f}s") + + # Map back to integer voxel coords (XYZ) + snapped_xyz = points_xyz[nn_indices].astype(np.int64) + + # Ensure snapped coords are valid (should already be in bounds) + snapped_xyz[:, 0] = np.clip(snapped_xyz[:, 0], 0, max_x) + snapped_xyz[:, 1] = np.clip(snapped_xyz[:, 1], 0, max_y) + snapped_xyz[:, 2] = np.clip(snapped_xyz[:, 2], 0, max_z) + + log(f"[{tag}] snapped {len(seeds_xyz)} seeds | total {perf_counter()-t0:.3f}s") + if return_index: + return snapped_xyz, nn_indices + else: + return snapped_xyz + + +# ============================================================ +# EDT wrapper (Seung-Lab edt preferred, fallback to scipy) +# ============================================================ +def _compute_edt(mask: np.ndarray, sampling_zyx, log=lambda x: None, tag="edt"): + """ + Compute Euclidean distance transform using Seung-Lab edt if available, + otherwise fallback to scipy.ndimage.distance_transform_edt. 
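
    Downstream, connect_both_seeds_via_ridge turns this distance map into a travel cost of
    roughly 1 / (eps + (d / d_max) ** ridge_power), so shortest paths hug the medial axis of
    the segment. A toy sketch of that conversion, assuming only SciPy:

        import numpy as np
        from scipy import ndimage as ndi

        mask = np.zeros((3, 9, 9), dtype=bool)     # (Z, Y, X) slab
        mask[1, 2:7, 2:7] = True

        dist = ndi.distance_transform_edt(mask, sampling=(40.0, 8.0, 8.0))
        dn = dist / dist.max()                     # normalised distance to boundary
        ridge_power, eps = 2.0, 1e-6
        cost = np.full(mask.shape, 1e12)           # effectively impassable outside the mask
        cost[mask] = 1.0 / (eps + dn[mask] ** ridge_power)   # cheap on the ridge, expensive near edges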
+ + - mask: boolean array in ZYX order + - sampling_zyx: anisotropy tuple in ZYX (float) + """ + t0 = perf_counter() + if _HAVE_EDT_FAST: + dist = _edt_fast(mask.astype(np.uint8, copy=False), anisotropy=sampling_zyx) + log(f"[{tag}] Seung-Lab edt | {perf_counter()-t0:.3f}s") + return dist + else: + dist = ndi.distance_transform_edt(mask, sampling=sampling_zyx) + log(f"[{tag}] SciPy EDT | {perf_counter()-t0:.3f}s") + return dist + + +# ------------------------------------------------------------ +# Helpers for upsampling +# ------------------------------------------------------------ +def _upsample_bool(mask_ds, steps, target_shape): + up = mask_ds.repeat(steps[0], 0).repeat(steps[1], 1).repeat(steps[2], 2) + return up[: target_shape[0], : target_shape[1], : target_shape[2]] + + +def _upsample_labels(lbl_ds, steps, target_shape): + up = lbl_ds.repeat(steps[0], 0).repeat(steps[1], 1).repeat(steps[2], 2) + return up[: target_shape[0], : target_shape[1], : target_shape[2]] + + +# ============================================================ +# Combined connector (ROI + DS + MST paths) — uses snapping + fast EDT +# ============================================================ +def connect_both_seeds_via_ridge( + binary_sv: np.ndarray, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + *, + vol_order: str = "xyz", + vox_order: str = "xyz", + seed_order: str = "xyz", + ridge_power: float = 2.0, + roi_pad_zyx=(24, 48, 48), + downsample=(2, 2, 1), + refine_fullres_when_fail: bool = True, + snap_method: str = "kdtree", + snap_kwargs: dict | None = None, + verbose: bool = True, +): + def log(msg: str): + if verbose: + print(msg, flush=True) + + def _bbox_pad_zyx(points_zyx, shape, pad=(24, 48, 48)): + pts = np.asarray(points_zyx, int) + if pts.size == 0: + return (0, 0, 0, shape[0], shape[1], shape[2]) + z0, y0, x0 = pts.min(0) + z1, y1, x1 = pts.max(0) + 1 + z0 = max(0, z0 - pad[0]) + y0 = max(0, y0 - pad[1]) + x0 = max(0, x0 - pad[2]) + z1 = min(shape[0], z1 + pad[0]) + y1 = min(shape[1], y1 + pad[1]) + x1 = min(shape[2], x1 + pad[2]) + return (z0, y0, x0, z1, y1, x1) + + def _mst_edges_phys(pts_zyx, sampling): + P = np.asarray(pts_zyx, float) + if len(P) <= 1: + return [] + S = np.array(sampling, float)[None, :] + phys = P * S + n = len(P) + in_tree = np.zeros(n, bool) + in_tree[0] = True + best = np.full(n, np.inf) + parent = np.full(n, -1, int) + d0 = np.sqrt(((phys - phys[0]) ** 2).sum(1)) + best[:] = d0 + best[0] = np.inf + parent[:] = 0 + edges = [] + for _ in range(n - 1): + i = int(np.argmin(best)) + if not np.isfinite(best[i]): + break + edges.append((int(parent[i]), i)) + in_tree[i] = True + best[i] = np.inf + di = np.sqrt(((phys - phys[i]) ** 2).sum(1)) + relax = (~in_tree) & (di < best) + parent[relax] = i + best[relax] = di[relax] + return edges + + t0 = perf_counter() + log( + f"[connect] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}" + ) + log( + f"[connect] mask shape: {binary_sv.shape}, ridge_power={ridge_power}, ds={downsample}" + ) + + sv_zyx, _ = _to_internal_zyx_volume(binary_sv, vol_order) + sampling = _to_zyx_sampling(voxel_size, vox_order) + + # SNAP seeds to mask + A_in_zyx = _seeds_to_zyx(seeds_a, seed_order) + B_in_zyx = _seeds_to_zyx(seeds_b, seed_order) + + # Default snapping config; override via snap_kwargs + snap_cfg = dict( + use_boundary=True, + erosion_iters=1, + downsample=True, + downsample_mode="random", + downsample_target=50_000, + method=snap_method, # allow pass-through compatibility + ) + if snap_kwargs is not None: + 
snap_cfg.update(snap_kwargs) + + def _snap(pts_zyx, name): + if pts_zyx.size == 0: + return np.empty((0, 3), dtype=int) + # Convert ZYX -> XYZ for snapper + pts_xyz = pts_zyx[:, [2, 1, 0]] + # Use snapping over full 3D sv_zyx with ZYX mask + snapped_xyz = snap_seeds_to_segment( + pts_xyz, + mask=sv_zyx, + mask_order="zyx", + voxel_size=( + sampling[2], + sampling[1], + sampling[0], + ), # convert ZYX->XYZ spacing + log=log, + tag=f"{name}@snap", + **snap_cfg, + ) + # Back to ZYX + return snapped_xyz[:, [2, 1, 0]] + + A_zyx = _snap(A_in_zyx, "A") + B_zyx = _snap(B_in_zyx, "B") + + if len(A_zyx) == 0 or len(B_zyx) == 0: + log("[connect] after snapping, one side has no seeds; skipping connection") + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + (len(A_zyx) > 0), + (len(B_zyx) > 0), + ) + + # ROI for speed + z0, y0, x0, z1, y1, x1 = _bbox_pad_zyx( + np.vstack([A_zyx, B_zyx]), sv_zyx.shape, pad=roi_pad_zyx + ) + roi = sv_zyx[z0:z1, y0:y1, x0:x1] + log(f"[connect] ROI: z[{z0}:{z1}] y[{y0}:{y1}] x[{x0}:{x1}] → shape {roi.shape}") + + # Downsample ROI + sz, sy, sx = map(int, downsample) + ti_ds = perf_counter() + if (sz, sy, sx) != (1, 1, 1): + roi_ds = roi[::sz, ::sy, ::sx] + else: + roi_ds = roi + sampling_ds = (sampling[0] * sz, sampling[1] * sy, sampling[2] * sx) + log( + f"[connect] ROI downsampled {roi.shape} -> {roi_ds.shape} | {perf_counter()-ti_ds:.3f}s" + ) + + # Robust seed placement on the downsampled grid: + # (1) Map to ROI-local coords + # (2) Divide by (sz,sy,sx) to approximate DS coords + # (3) SNAP them to the nearest True voxel in roi_ds using KDTree + def _to_roi_ds_snapped(pts_zyx, name="seedDS"): + if pts_zyx.size == 0: + return np.empty((0, 3), dtype=int) + local = np.asarray(pts_zyx, int) - np.array([z0, y0, x0]) # roi-local + seeds_ds = local / np.array( + [sz, sy, sx], dtype=float + ) # DS coordinates (float OK) + # Convert to XYZ for snapper + seeds_ds_xyz = seeds_ds[:, [2, 1, 0]] + try: + snapped_ds_xyz = snap_seeds_to_segment( + seeds_ds_xyz, + mask=roi_ds, + mask_order="zyx", + voxel_size=(sampling_ds[2], sampling_ds[1], sampling_ds[0]), + log=log, + tag=f"{name}@roi_ds", + use_boundary=False, + downsample=False, + method="kdtree", + ) + snapped_ds_zyx = snapped_ds_xyz[:, [2, 1, 0]] + return snapped_ds_zyx.astype(int) + except ValueError as e: + # If roi_ds is empty or degenerate, bail out gracefully: + log( + f"[{name}@roi_ds] snapping failed ({e}); falling back to nearest-int grid & mask check." + ) + approx = np.floor(seeds_ds + 0.5).astype(int) + Z, Y, X = roi_ds.shape + approx[:, 0] = np.clip(approx[:, 0], 0, Z - 1) + approx[:, 1] = np.clip(approx[:, 1], 0, Y - 1) + approx[:, 2] = np.clip(approx[:, 2], 0, X - 1) + # Keep only those approx coords that are inside mask + valid = [tuple(p) for p in approx if roi_ds[tuple(p)]] + return np.array(valid, dtype=int) + + A_ds = _to_roi_ds_snapped(A_zyx, "A") + B_ds = _to_roi_ds_snapped(B_zyx, "B") + + okA = len(A_ds) >= 1 + okB = len(B_ds) >= 1 + if not (okA and okB): + log( + "[connect] seeds disappeared or failed to map on DS grid; consider smaller ds or use_boundary=False/downsample=False in snapping." 
+ ) + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + okA, + okB, + ) + + # EDT and cost on DS ROI (Seung-Lab edt if available) + t1 = perf_counter() + dist = _compute_edt(roi_ds, sampling_ds, log=log, tag="connect:EDT") + if dist.max() <= 0: + log("[connect] empty EDT in ROI; skipping connection") + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + False, + False, + ) + dn = dist / dist.max() + eps = 1e-6 + cost = np.full_like(dn, 1e12, dtype=float) + cost[roi_ds] = 1.0 / (eps + np.clip(dn[roi_ds], 0, 1) ** max(0.0, ridge_power)) + log(f"[connect] EDT/cost ready on DS-ROI | {perf_counter()-t1:.3f}s") + + # Shortest paths via MST + def _path_mask_ds(start, end): + tmcp = perf_counter() + mcp = MCP_Geometric(cost, sampling=sampling_ds) + costs, _ = mcp.find_costs([tuple(start)], find_all_ends=False) + mid = perf_counter() + v = costs[tuple(end)] + if not np.isfinite(v): + log( + f"[MCP] start={tuple(start)} -> end={tuple(end)} FAILED | setup+run={mid-tmcp:.3f}s" + ) + return None + path = np.asarray(mcp.traceback(tuple(end)), int) + m = np.zeros_like(roi_ds, bool) + m[tuple(path.T)] = True + log( + f"[MCP] start={tuple(start)} -> end={tuple(end)} OK | total={perf_counter()-tmcp:.3f}s" + ) + return m + + def _augment_team_ds(team_name, pts_ds): + if len(pts_ds) <= 1: + return np.zeros_like(roi_ds, bool), True + edges = _mst_edges_phys(pts_ds, sampling_ds) + pmask = np.zeros_like(roi_ds, bool) + ok = True + for i, j in edges: + m = _path_mask_ds(pts_ds[i], pts_ds[j]) + if m is None: + log(f"[connect:{team_name}] DS path FAILED for edge {i}-{j}") + ok = False + if refine_fullres_when_fail: + # fallback full-res EDT and path + tfr = perf_counter() + dist_fr = _compute_edt( + roi, sampling, log=log, tag="connect:EDT(fullres)" + ) + dnm = dist_fr / (dist_fr.max() if dist_fr.max() > 0 else 1.0) + cost_fr = np.full_like(dist_fr, 1e12, dtype=float) + cost_fr[roi] = 1.0 / ( + eps + np.clip(dnm[roi], 0, 1) ** max(0.0, ridge_power) + ) + s = np.array(pts_ds[i]) * np.array([sz, sy, sx]) + e = np.array(pts_ds[j]) * np.array([sz, sy, sx]) + mcp_fr = MCP_Geometric(cost_fr, sampling=sampling) + costs_fr, _ = mcp_fr.find_costs([tuple(s)], find_all_ends=False) + if np.isfinite(costs_fr[tuple(e)]): + path_fr = np.asarray(mcp_fr.traceback(tuple(e)), int) + m_fr = np.zeros_like(roi, bool) + m_fr[tuple(path_fr.T)] = True + m = m_fr[::sz, ::sy, ::sx] + ok = True + log( + f"[connect:{team_name}] fallback full-res path OK | {perf_counter()-tfr:.3f}s" + ) + else: + log( + f"[connect:{team_name}] Full-res ROI path also FAILED for edge {i}-{j}" + ) + m = None + if m is not None: + pmask |= m + return pmask, ok + + t_aug = perf_counter() + pA_ds, okA2 = _augment_team_ds("A", A_ds) + pB_ds, okB2 = _augment_team_ds("B", B_ds) + okA &= okA2 + okB &= okB2 + log(f"[connect] MST+paths built | {perf_counter()-t_aug:.3f}s") + + if not (okA and okB): + log( + "[connect] connection failed for at least one team — consider smaller downsample or refine_fullres_when_fail." 
+ ) + return ( + _seeds_from_zyx(A_zyx, seed_order), + _seeds_from_zyx(B_zyx, seed_order), + okA, + okB, + ) + + # Up-project to full resolution and dilate + pA = _upsample_bool(pA_ds, (sz, sy, sx), roi.shape) & roi + pB = _upsample_bool(pB_ds, (sz, sy, sx), roi.shape) & roi + struc = ball(1) + tpost = perf_counter() + pA = ndi.binary_dilation(pA, structure=struc) & roi + pB = ndi.binary_dilation(pB, structure=struc) & roi + log(f"[connect] postproc dilation on paths | {perf_counter()-tpost:.3f}s") + + A_aug = set(map(tuple, A_zyx)) + B_aug = set(map(tuple, B_zyx)) + Az, Ay, Ax = np.nonzero(pA) + Bz, By, Bx = np.nonzero(pB) + for z, y, x in zip(Az, Ay, Ax): + A_aug.add((z0 + z, y0 + y, x0 + x)) + for z, y, x in zip(Bz, By, Bx): + B_aug.add((z0 + z, y0 + y, x0 + x)) + + A_aug = _seeds_from_zyx(np.array(sorted(list(A_aug)), int), seed_order) + B_aug = _seeds_from_zyx(np.array(sorted(list(B_aug)), int), seed_order) + log( + f"[connect] done; +{len(A_aug)-len(seeds_a)} vox for A, +{len(B_aug)-len(seeds_b)} for B | total {perf_counter()-t0:.3f}s" + ) + return A_aug, B_aug, True, True + + +def split_supervoxel_growing( + binary_sv: np.ndarray, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + *, + # conventions / orders + vol_order: str = "xyz", + vox_order: str = "xyz", + seed_order: str = "xyz", + # geometry / cost + halo: int = 1, + gamma_neck: float = 1.6, # boundary slowdown + k_prox: float = 2.0, # proximity boost strength + lambda_prox: float = 1.0, # proximity decay + narrow_band_rel: float = 0.08, # relative difference threshold + nb_dilate: int = 1, # dilate band to stabilize + # optional: compute TA/TB on a downsampled grid + downsample_geodesic: tuple | None = None, # e.g. (1,2,2) + # post-processing / guarantees + allow_third_label: bool = True, + enforce_single_cc: bool = True, + # final validation + check_seeds_same_cc: bool = True, + raise_if_seed_split: bool = True, + raise_if_multi_cc: bool = False, + # snapping control (NEW) + snap_method: str = "kdtree", + snap_kwargs: dict | None = None, + # logging + verbose: bool = True, +): + def log(msg: str): + if verbose: + print(msg, flush=True) + + # Helpers reused from the module: _cc_label_26, _largest_component_id, _to_internal_zyx_volume, _from_internal_zyx_volume + # _seeds_to_zyx, _compute_edt, etc. are assumed available. 
+ + # ---------- helpers ---------- + def _enforce_single_component(out_labels, lab, seed_pts_global, allow3=True): + t = perf_counter() + mask = out_labels == lab + if not np.any(mask): + return 0, 0 + comp, ncomp = _cc_label_26(mask) + if ncomp <= 1: + log(f"[single-cc:{lab}] ncomp=1 | {perf_counter()-t:.3f}s") + return 1, 0 + + keep_ids = set() + for z, y, x in seed_pts_global: + if ( + 0 <= z < out_labels.shape[0] + and 0 <= y < out_labels.shape[1] + and 0 <= x < out_labels.shape[2] + ): + if out_labels[z, y, x] == lab: + cid = comp[z, y, x] + if cid > 0: + keep_ids.add(int(cid)) + + if not keep_ids: + keep_ids = {_largest_component_id(comp)} + + lut = np.zeros(ncomp + 1, dtype=np.bool_) + lut[list(keep_ids)] = True + bad_mask = (comp > 0) & (~lut[comp]) + moved = int(bad_mask.sum()) + if allow3 and moved: + out_labels[bad_mask] = 3 + log( + f"[single-cc:{lab}] kept={len(keep_ids)}, moved_to_3={moved} | {perf_counter()-t:.3f}s" + ) + return len(keep_ids), moved + + def _resolve_label3_touching_vectorized( + out_labels, seedsA=None, seedsB=None, sampling=(1, 1, 1) + ): + t0 = perf_counter() + comp3, n3 = _cc_label_26(out_labels == 3) + n3_vox = int((out_labels == 3).sum()) + log(f"[touching] n3 comps={n3}, vox={n3_vox}") + if n3 == 0: + log(f"[touching] no label-3 components | {perf_counter()-t0:.3f}s") + return 0, 0 + + t1 = perf_counter() + struc = np.ones((3, 3, 3), bool) + N1 = ndi.binary_dilation(out_labels == 1, structure=struc) & (comp3 > 0) + N2 = ndi.binary_dilation(out_labels == 2, structure=struc) & (comp3 > 0) + + cnt1 = np.bincount(comp3[N1], minlength=n3 + 1) + cnt2 = np.bincount(comp3[N2], minlength=n3 + 1) + + assign = np.zeros(n3 + 1, dtype=np.int16) # 0=undecided, 1 or 2 otherwise + assign[cnt1 > cnt2] = 1 + assign[cnt2 > cnt1] = 2 + undec = np.where(assign[1:] == 0)[0] + 1 + log( + f"[touching] maj→1={int((assign==1).sum())}, maj→2={int((assign==2).sum())}, ties={len(undec)} | {perf_counter()-t1:.3f}s" + ) + + if ( + len(undec) > 0 + and (seedsA is not None) + and (seedsB is not None) + and len(seedsA) + and len(seedsB) + ): + t2 = perf_counter() + sA = np.zeros_like(out_labels, bool) + sA[tuple(np.array(seedsA).T)] = True + sB = np.zeros_like(out_labels, bool) + sB[tuple(np.array(seedsB).T)] = True + dA = _compute_edt(~sA, sampling, log=log, tag="split:EDT(dA)") + dB = _compute_edt(~sB, sampling, log=log, tag="split:EDT(dB)") + closer2 = (dB < dA) & (comp3 > 0) + + pref2 = np.bincount(comp3[closer2], minlength=n3 + 1) + total = np.bincount(comp3[comp3 > 0], minlength=n3 + 1) + + tie_ids = np.array(undec, dtype=int) + choose2 = pref2[tie_ids] > (total[tie_ids] - pref2[tie_ids]) + assign[tie_ids[choose2]] = 2 + assign[tie_ids[~choose2]] = 1 + log( + f"[touching] tie-break EDT done: to2={int(choose2.sum())}, to1={int((~choose2).sum())} | {perf_counter()-t2:.3f}s" + ) + + moved1 = moved2 = 0 + if (assign == 1).any(): + mask1 = assign[comp3] == 1 + moved1 = int(mask1.sum()) + out_labels[mask1] = 1 + if (assign == 2).any(): + mask2 = assign[comp3] == 2 + moved2 = int(mask2.sum()) + out_labels[mask2] = 2 + + log( + f"[touching] reassigned 3→1: {moved1}, 3→2: {moved2} | total {perf_counter()-t0:.3f}s" + ) + return moved1, moved2 + + # ---------- begin ---------- + t0 = perf_counter() + log(f"[init] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}") + log(f"[init] input volume shape: {binary_sv.shape}") + + # Convert input volumes and sampling into internal ZYX + sv_zyx, _ = _to_internal_zyx_volume(binary_sv, vol_order) + sampling = 
_to_zyx_sampling(voxel_size, vox_order) + log(f"[init] internal shape (z,y,x): {sv_zyx.shape}") + log(f"[init] sampling (z,y,x): {sampling}") + + # SNAP seeds to mask using the same KDTree-based method + A_all = _seeds_to_zyx(seeds_a, seed_order) + B_all = _seeds_to_zyx(seeds_b, seed_order) + log("[snap] snapping seeds to segment mask...") + + snap_cfg = dict( + use_boundary=True, + erosion_iters=1, + downsample=True, + downsample_mode="random", + downsample_target=50_000, + method=snap_method, # compatibility key + ) + if snap_kwargs is not None: + snap_cfg.update(snap_kwargs) + + def _snap_ZYX(pts_zyx, tagname): + if pts_zyx.size == 0: + return np.empty((0, 3), dtype=int) + # Convert ZYX -> XYZ for snapper + pts_xyz = pts_zyx[:, [2, 1, 0]] + snapped_xyz = snap_seeds_to_segment( + pts_xyz, + mask=sv_zyx, + mask_order="zyx", + voxel_size=( + sampling[2], + sampling[1], + sampling[0], + ), # convert ZYX→XYZ spacing + log=log, + tag=tagname, + **snap_cfg, + ) + return snapped_xyz[:, [2, 1, 0]] + + A = _snap_ZYX(A_all, "A@snap") + B = _snap_ZYX(B_all, "B@snap") + log(f"[seeds] A={len(A)}, B={len(B)}") + + out_zyx = np.zeros_like(sv_zyx, dtype=np.int16) + if A.size == 0 or B.size == 0 or not np.any(sv_zyx): + log("[seeds] missing seeds or empty SV; returning label=1 for entire SV") + out_zyx[sv_zyx] = 1 + return _from_internal_zyx_volume(out_zyx, vol_order) + + # Tight bbox ROI around mask with halo + t_bbox = perf_counter() + Z, Y, X = sv_zyx.shape + coords = np.argwhere(sv_zyx) + z0, y0, x0 = coords.min(0) + z1, y1, x1 = coords.max(0) + 1 + z0h = max(z0 - halo, 0) + y0h = max(y0 - halo, 0) + x0h = max(x0 - halo, 0) + z1h = min(z1 + halo, Z) + y1h = min(y1 + halo, Y) + x1h = min(x1 + halo, X) + sv = sv_zyx[z0h:z1h, y0h:y1h, x0h:x1h] + A_roi = A - np.array([z0h, y0h, x0h]) + B_roi = B - np.array([z0h, y0h, x0h]) + log( + f"[crop] ROI shape (internal): {sv.shape} (halo {halo}) | {perf_counter()-t_bbox:.3f}s" + ) + + # Build travel cost via EDT (Seung-Lab edt if available) + t1 = perf_counter() + dist = _compute_edt(sv, sampling, log=log, tag="split:EDT(mask)") + distn = dist / dist.max() if dist.max() > 0 else dist + eps = 1e-6 + speed = np.clip(distn ** max(gamma_neck, 0.0), eps, 1.0) + travel_cost = np.full_like(speed, 1e12, dtype=float) + travel_cost[sv] = 1.0 / speed[sv] + log( + f"[speed] EDT + speed map | {perf_counter()-t1:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Optional downsample for geodesic + use_ds = downsample_geodesic is not None + if use_ds: + dz, dy, dx = map(int, downsample_geodesic) + log(f"[geodesic] downsample grid: {downsample_geodesic}") + cost_ds = travel_cost[::dz, ::dy, ::dx] + mask_ds = sv[::dz, ::dy, ::dx] + sampling_ds = (sampling[0] * dz, sampling[1] * dy, sampling[2] * dx) + + def _to_ds(pts): + pts = (np.asarray(pts, int) // np.array([dz, dy, dx])).astype(int) + Zs, Ys, Xs = mask_ds.shape + keep = [] + for z, y, x in pts: + if 0 <= z < Zs and 0 <= y < Ys and 0 <= x < Xs and mask_ds[z, y, x]: + keep.append((z, y, x)) + return keep + + A_sub = _to_ds(A_roi) + B_sub = _to_ds(B_roi) + log(f"[geodesic] seeds on DS grid: A={len(A_sub)}, B={len(B_sub)}") + if len(A_sub) == 0 or len(B_sub) == 0: + log("[geodesic] DS removed all seeds; falling back to full-res") + use_ds = False + if not use_ds: + cost_ds = travel_cost + mask_ds = sv + sampling_ds = sampling + A_sub = [tuple(p) for p in A_roi.tolist()] + B_sub = [tuple(p) for p in B_roi.tolist()] + + # Geodesic arrival times + t2 = perf_counter() + mcpA = MCP_Geometric(cost_ds, sampling=sampling_ds) + TA, _ = 
mcpA.find_costs(A_sub, find_all_ends=False) + mcpB = MCP_Geometric(cost_ds, sampling=sampling_ds) + TB, _ = mcpB.find_costs(B_sub, find_all_ends=False) + TA = np.where(mask_ds, TA, np.inf) + TB = np.where(mask_ds, TB, np.inf) + log( + f"[geodesic] TA/TB computed | {perf_counter()-t2:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Narrow band + t3 = perf_counter() + finite = np.isfinite(TA) & np.isfinite(TB) & mask_ds + denom = TA + TB + 1e-12 + reldiff = np.zeros_like(TA) + reldiff[finite] = np.abs(TA[finite] - TB[finite]) / denom[finite] + band = finite & (reldiff <= narrow_band_rel) + if nb_dilate > 0: + band = ndi.binary_dilation(band, structure=ball(nb_dilate)) & mask_ds + if band.sum() < 64: + band = mask_ds.copy() + log("[band] tiny band -> using full ROI on current grid") + log( + f"[band] voxels: {int(band.sum())} | {perf_counter()-t3:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Proximity-boosted labeling + t4 = perf_counter() + denomA = 1.0 + k_prox * np.exp(-lambda_prox * np.clip(TB, 0, np.inf)) + denomB = 1.0 + k_prox * np.exp(-lambda_prox * np.clip(TA, 0, np.inf)) + CA = TA / denomA + CB = TB / denomB + sub_labels_ds = np.zeros_like(mask_ds, dtype=np.int16) + sub_labels_ds[(CA <= CB) & band] = 1 + sub_labels_ds[(CB < CA) & band] = 2 + outer = mask_ds & (sub_labels_ds == 0) + sub_labels_ds[(TA <= TB) & outer] = 1 + sub_labels_ds[(TB < TA) & outer] = 2 + for z, y, x in A_sub: + sub_labels_ds[z, y, x] = 1 + for z, y, x in B_sub: + sub_labels_ds[z, y, x] = 2 + log( + f"[label] DS labeling done | {perf_counter()-t4:.3f}s (total {perf_counter()-t0:.3f}s)" + ) + + # Upsample if needed + if use_ds: + sub_labels = _upsample_labels(sub_labels_ds, (dz, dy, dx), sv.shape) + sub_labels[~sv] = 0 + for z, y, x in A_roi: + sub_labels[z, y, x] = 1 + for z, y, x in B_roi: + sub_labels[z, y, x] = 2 + log(f"[label] upsampled DS→full ROI") + else: + sub_labels = sub_labels_ds + + # Writeback + out_zyx[sv_zyx] = 1 + out_zyx[z0h:z1h, y0h:y1h, x0h:x1h][sub_labels == 1] = 1 + out_zyx[z0h:z1h, y0h:y1h, x0h:x1h][sub_labels == 2] = 2 + log("[writeback] labels written to full volume") + + # Enforce single CC per label + if enforce_single_cc: + keptA, movedA = _enforce_single_component( + out_zyx, 1, A, allow3=allow_third_label + ) + keptB, movedB = _enforce_single_component( + out_zyx, 2, B, allow3=allow_third_label + ) + log( + f"[single-cc] label1 kept {keptA}, moved {movedA} -> 3; label2 kept {keptB}, moved {movedB} -> 3" + ) + + # Resolve 3-touching + moved1, moved2 = _resolve_label3_touching_vectorized(out_zyx, A, B, sampling) + if moved1 or moved2: + if enforce_single_cc: + keptA, movedA = _enforce_single_component( + out_zyx, 1, A, allow3=allow_third_label + ) + keptB, movedB = _enforce_single_component( + out_zyx, 2, B, allow3=allow_third_label + ) + log( + f"[single-cc 2nd] label1 kept {keptA}, moved {movedA}; label2 kept {keptB}, moved {movedB}" + ) + + # Final check + for lab in (1, 2): + _, ncomp = _cc_label_26(out_zyx == lab) + if ncomp > 1: + msg = f"[check] label {lab} has {ncomp} connected components" + if raise_if_multi_cc: + raise ValueError(msg) + else: + log(msg) + + log(f"[done] total elapsed {perf_counter()-t0:.3f}s") + return _from_internal_zyx_volume(out_zyx, vol_order) + + +def build_kdtrees_by_label( + vol: np.ndarray, + *, + background: int = 0, + leafsize: int = 16, + balanced_tree: bool = True, + compact_nodes: bool = True, + min_points: int = 1, + dtype: np.dtype = np.float32, +) -> Tuple[Dict[int, cKDTree], Dict[int, int]]: + """ + Build a cKDTree of voxel coordinates 
for every unique (non-background) label in a 3D volume. + + Parameters + ---------- + vol : np.ndarray + 3D label volume (e.g., shape (Z, Y, X)). Can be any integer dtype (incl. uint64). + background : int, default 0 + Label treated as background and skipped. + leafsize : int, default 16 + Passed to cKDTree (larger can be faster for queries on large trees). + balanced_tree : bool, default True + Passed to cKDTree. + compact_nodes : bool, default True + Passed to cKDTree. + min_points : int, default 1 + Skip labels with fewer than this many voxels. + dtype : np.dtype, default np.float32 + Coordinate dtype used to build the trees (lower memory than float64). + + Returns + ------- + trees : Dict[int, cKDTree] + Mapping label -> cKDTree built + from the (z, y, x) coordinates of that label’s voxels. + counts : Dict[int, int] + Mapping label -> number of voxels used to build the tree. + + Notes + ----- + - This runs in O(N log N) due to a single sort over N foreground voxels. + - Uses one pass over non-background voxels; avoids per-label boolean masking. + - Coordinates are (z, y, x) in voxel units. + """ + if vol.ndim != 3: + raise ValueError("`vol` must be a 3D array.") + Z, Y, X = vol.shape + + # Flatten once and select foreground voxels + flat = vol.ravel() + if background == 0: + nz = np.flatnonzero(flat) # fast path when background is 0 + else: + nz = np.flatnonzero(flat != background) + + if nz.size == 0: + return {}, {} + + # Labels of foreground voxels (kept as integer/uint64) + labels = flat[nz] + + # Coordinates for those voxels (computed once) + z, y, x = np.unravel_index(nz, (Z, Y, X)) + coords = np.column_stack((z, y, x)).astype(dtype, copy=False) + + # Group by label via sort (stable to preserve any incidental ordering) + order = np.argsort(labels, kind="mergesort") + labels_sorted = labels[order] + + # Find group boundaries (run-length encoding over sorted labels) + starts = np.flatnonzero(np.r_[True, labels_sorted[1:] != labels_sorted[:-1]]) + ends = np.r_[starts[1:], labels_sorted.size] + + trees: Dict[int, cKDTree] = {} + counts: Dict[int, int] = {} + + for s, e in zip(starts, ends): + lab = int(labels_sorted[s]) # Python int key (handles uint64 safely) + block = coords[order[s:e]] + n = block.shape[0] + if n < min_points: + continue + # cKDTree copies data into its own memory; no need to keep `block` afterwards. + trees[lab] = cKDTree( + block, + leafsize=leafsize, + balanced_tree=balanced_tree, + compact_nodes=compact_nodes, + ) + counts[lab] = n + + return trees, counts + + +def pairwise_min_distance_two_sets( + trees_a: Sequence[cKDTree], + trees_b: Sequence[cKDTree], + *, + max_distance: Optional[float] = None, + workers: int = -1, +) -> np.ndarray: + """ + Compute pairwise shortest distances between point sets represented by two lists + of cKDTrees. Result has shape (len(trees_a), len(trees_b)). + + Parameters + ---------- + trees_a, trees_b : sequences of cKDTree + Each tree encodes the (z,y,x) points for one segment. + max_distance : float or None + If None (default): compute exact min distances (dense, finite). + If set: compute within this cutoff using sparse_distance_matrix; pairs with + no neighbors within cutoff are set to np.inf. + workers : int + Parallelism for cKDTree.query (SciPy >= 1.6). -1 uses all cores. + + Returns + ------- + D : ndarray, shape (len(trees_a), len(trees_b)) + D[i,j] = min distance between any point in trees_a[i] and trees_b[j]. + If max_distance is not None, entries may be np.inf. 
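
    In exact mode the minimum distance between two segments is just a k=1 query of one
    tree's points against the other tree, followed by a min. A small worked example with
    toy coordinates (not part of the patch):

        import numpy as np
        from scipy.spatial import cKDTree

        a = cKDTree(np.array([[0, 0, 0], [0, 0, 5]], dtype=float))
        b = cKDTree(np.array([[0, 3, 0], [10, 0, 0]], dtype=float))

        # query the points of one tree against the other and take the minimum
        d, _ = b.query(a.data, k=1)
        min_dist = float(np.min(d))   # -> 3.0 for these points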
+ """ + A, B = len(trees_a), len(trees_b) + if A == 0 or B == 0: + return np.zeros((A, B), dtype=float) + + D = np.zeros((A, B), dtype=float) + + if max_distance is not None: + # Cutoff mode: faster when many pairs are far apart. + D.fill(np.inf) + for i in range(A): + ti = trees_a[i] + for j in range(B): + tj = trees_b[j] + s = ti.sparse_distance_matrix( + tj, max_distance, output_type="coo_matrix" + ) + if s.nnz > 0: + D[i, j] = float(s.data.min()) + return D + + # Exact mode: query points of the smaller tree into the larger tree (k=1) and take min. + for i in range(A): + ti = trees_a[i] + ni = ti.n + for j in range(B): + tj = trees_b[j] + nj = tj.n + if ni <= nj: + d, _ = tj.query(ti.data, k=1, workers=workers) + else: + d, _ = ti.query(tj.data, k=1, workers=workers) + # d can be scalar if one tree has 1 point; np.min handles both + D[i, j] = float(np.min(d)) + return D + + +def split_supervoxel_helper( + binary_seg: np.ndarray, + source_coords: np.ndarray, + sink_coords: np.ndarray, + voxel_size: tuple, + verbose: bool = False, +): + voxel_size = np.array(voxel_size) + downsample = voxel_size.max() // voxel_size + + # 1) Connect seed teams first + A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge( + binary_seg, + source_coords, + sink_coords, + voxel_size=voxel_size, + downsample=downsample, + vol_order="xyz", + vox_order="xyz", + seed_order="xyz", + snap_method="kdtree", + snap_kwargs=dict( + use_boundary=False, # disables boundary-only snapping for maximum safety + downsample=False, # avoids losing candidates + method="kdtree", + ), + verbose=verbose, + ) + if not (okA and okB): + raise RuntimeError( + "In-mask connection failed for at least one team; skipping split." + ) + + # 2) Run the corridor-free splitter with same snapping settings + return split_supervoxel_growing( + binary_seg, + A_aug, + B_aug, + voxel_size=voxel_size, + vol_order="xyz", + vox_order="xyz", + seed_order="xyz", + halo=1, + gamma_neck=1.6, + narrow_band_rel=0.08, + nb_dilate=1, + downsample_geodesic=(1, 2, 2), + enforce_single_cc=True, + raise_if_seed_split=True, + raise_if_multi_cc=True, + verbose=verbose, + snap_method="kdtree", + snap_kwargs=dict( + use_boundary=False, # match the connector for consistency + downsample=False, + method="kdtree", + ), + ) diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py new file mode 100644 index 000000000..bb50505b0 --- /dev/null +++ b/pychunkedgraph/graph/edits_sv.py @@ -0,0 +1,439 @@ +""" +Manage new supervoxels after a supervoxel split. 
+""" + +from functools import reduce +import logging +import multiprocessing as mp +from typing import Callable, Iterable +from datetime import datetime +from collections import defaultdict, deque + +import fastremap +import numpy as np +from tqdm import tqdm +from pychunkedgraph.graph import ChunkedGraph, cache as cache_utils +from pychunkedgraph.graph.attributes import Connectivity +from pychunkedgraph.graph.chunks.utils import chunks_overlapping_bbox, get_neighbors +from pychunkedgraph.graph.cutting_sv import ( + build_kdtrees_by_label, + pairwise_min_distance_two_sets, + split_supervoxel_helper, +) +from pychunkedgraph.graph.attributes import Hierarchy, OperationLogs +from pychunkedgraph.graph.edges import Edges +from pychunkedgraph.graph.types import empty_2d +from pychunkedgraph.graph.utils import basetypes +from pychunkedgraph.graph.utils import get_local_segmentation +from pychunkedgraph.graph.utils.serializers import serialize_uint64 +from pychunkedgraph.io.edges import get_chunk_edges + + +def _get_whole_sv( + cg: ChunkedGraph, node: basetypes.NODE_ID, min_coord, max_coord +) -> set: + cx_edges = [empty_2d] + explored_chunks = set() + explored_nodes = set([node]) + queue = deque([node]) + + while len(queue) > 0: + vertex = queue.popleft() + chunk = cg.get_chunk_coordinates(vertex) + chunks = get_neighbors(chunk, min_coord=min_coord, max_coord=max_coord) + + unexplored_chunks = [] + for _chunk in chunks: + if tuple(_chunk) not in explored_chunks: + unexplored_chunks.append(tuple(_chunk)) + + edges = get_chunk_edges(cg.meta.data_source.EDGES, unexplored_chunks) + explored_chunks.update(unexplored_chunks) + _cx_edges = edges["cross"].get_pairs() + cx_edges.append(_cx_edges) + _cx_edges = np.concatenate(cx_edges) + + mask = _cx_edges[:, 0] == vertex + neighbors = _cx_edges[mask][:, 1] + + neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) + min_mask = (neighbor_coords >= min_coord).all(axis=1) + max_mask = (neighbor_coords < max_coord).all(axis=1) + neighbors = neighbors[min_mask & max_mask] + + for neighbor in neighbors: + if neighbor in explored_nodes: + continue + explored_nodes.add(neighbor) + queue.append(neighbor) + return explored_nodes + + +def _update_chunk(args): + """ + For a chunk that overlaps bounding box for supervoxel split, + If chunk contains mask for the split supervoxel, + return indices of mask, old and new supervoxel IDs from this chunk. 
+ """ + graph_id, chunk_coord, chunk_bbox, seg, result_seg, bb_start = args + cg = ChunkedGraph(graph_id=graph_id) + x, y, z = chunk_coord + chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) + + # TODO: remove these 3 lines, testing only + rr = cg.range_read_chunk(chunk_id) + max_node_id = max(rr.keys()) + cg.id_client.set_max_node_id(chunk_id, max_node_id) + + _s, _e = chunk_bbox - bb_start + og_chunk_seg = seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] + chunk_seg = result_seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] + + labels = fastremap.unique(chunk_seg[chunk_seg != 0]) + if labels.size < 2: + return None + + _indices = [] + _old_values = [] + _new_values = [] + for _id in labels: + _mask = chunk_seg == _id + if np.any(_mask): + _idx = np.unravel_index(np.flatnonzero(_mask)[0], og_chunk_seg.shape) + _og_value = og_chunk_seg[_idx] + _index = np.argwhere(_mask) + _indices.append(_index) + _ones = np.ones(len(_index), dtype=basetypes.NODE_ID) + _old_values.append(_ones * _og_value) + _new_values.append(_ones * cg.id_client.create_node_id(chunk_id)) + + _indices = np.concatenate(_indices) + (chunk_bbox[0] - bb_start) + _old_values = np.concatenate(_old_values) + _new_values = np.concatenate(_new_values) + return (_indices, _old_values, _new_values) + + +def _voxel_crop(bbs, bbe, bbs_, bbe_): + xS, yS, zS = bbs - bbs_ + xE, yE, zE = (None if i == 0 else -1 for i in bbe_ - bbe) + voxel_overlap_crop = np.s_[xS:xE, yS:yE, zS:zE] + logging.info(f"voxel_overlap_crop: {voxel_overlap_crop}") + return voxel_overlap_crop + + +def _parse_results(results, seg, bbs, bbe): + old_new_map = defaultdict(set) + for result in results: + if result: + indexer, old_values, new_values = result + seg[tuple(indexer.T)] = new_values + for old_sv, new_sv in zip(old_values, new_values): + old_new_map[old_sv].add(new_sv) + + assert np.all(seg.shape == bbe - bbs), f"{seg.shape} != {bbe - bbs}" + slices = tuple(slice(start, end) for start, end in zip(bbs, bbe)) + (slice(None),) + logging.info(f"slices {slices}") + return seg, old_new_map, slices + + +def _get_new_edges( + edges_info: tuple, + sv_ids: np.ndarray, + old_new_map: dict, + distances: np.ndarray, + dist_vec: Callable, + new_dist_vec: Callable, +): + THRESHOLD = 10 + new_edges, new_affs, new_areas = [], [], [] + edges, affinities, areas = edges_info + + for old, new in old_new_map.items(): + logging.info(f"old and new {old, new}") + new_ids = np.array(list(new), dtype=basetypes.NODE_ID) + edges_m = np.any(edges == old, axis=1) + selected_edges = edges[edges_m] + sel_m = selected_edges != old + assert np.all(np.sum(sel_m, axis=1) == 1) + + partners = selected_edges[sel_m] + active_m = np.isin(partners, sv_ids) + + logging.info(f"sv_ids: {np.sum(sv_ids > 0)}") + logging.info(f"edges: {edges.shape} {np.sum(edges_m)} {np.sum(sel_m)}") + logging.info(f"selected_edges: {selected_edges.shape}") + + # inactive + for new_id in new_ids: + _a = [[new_id] * np.sum(~active_m), partners[~active_m]] + new_edges.extend(np.array(_a, dtype=np.uint64).T) + new_affs.extend(affinities[edges_m][np.any(sel_m, axis=1)][~active_m]) + new_areas.extend(areas[edges_m][np.any(sel_m, axis=1)][~active_m]) + + # active + active_partners_ = partners[active_m] + active_affs_ = affinities[edges_m][np.any(sel_m, axis=1)][active_m] + active_areas_ = areas[edges_m][np.any(sel_m, axis=1)][active_m] + + logging.info(f"partners: {partners.shape} {active_partners_.shape}") + + active_partners = [] + active_affs = [] + active_areas = [] + for i in range(len(active_partners_)): + remapped_ = 
old_new_map.get(active_partners_[i], [active_partners_[i]]) + active_partners.extend(remapped_) + active_affs.extend([active_affs_[i]] * len(remapped_)) + active_areas.extend([active_areas_[i]] * len(remapped_)) + + logging.info(f"new_ids, active_partners: {new_ids, len(active_partners)}") + logging.info(f"new_dist_vec(new_ids): {new_dist_vec(new_ids)}") + logging.info(f"dist_vec(active_partners): {dist_vec(active_partners)}") + distances_ = distances[new_dist_vec(new_ids)][:, dist_vec(active_partners)].T + for i, _ in enumerate(active_partners): + new_ids_ = new_ids[distances_[i] < THRESHOLD] + if len(new_ids_): + _a = [new_ids_, [active_partners[i]] * len(new_ids_)] + new_edges.extend(np.array(_a, dtype=np.uint64).T) + new_affs.extend([active_affs[i]] * len(new_ids_)) + new_areas.extend([active_areas[i]] * len(new_ids_)) + else: + close_new_sv_id = new_ids[np.argmin(distances_[i])] + _a = [close_new_sv_id, active_partners[i]] + new_edges.append(np.array(_a, dtype=np.uint64)) + new_affs.append(active_affs[i]) + new_areas.append(active_areas[i]) + + # edges between split fragments + for i in range(len(new_ids)): + for j in range(i + 1, len(new_ids)): # includes no selfedges + _a = [new_ids[i], new_ids[j]] + new_edges.append(np.array(_a, dtype=np.uint64)) + new_affs.append(0.001) + new_areas.append(0) + + affinites = np.array(new_affs, dtype=basetypes.EDGE_AFFINITY) + areas = np.array(new_areas, dtype=basetypes.EDGE_AREA) + edges = np.array(new_edges, dtype=basetypes.NODE_ID) + edges, idx = np.unique(edges, return_index=True, axis=0) + return edges, affinites[idx], areas[idx] + + +def _update_edges( + cg: ChunkedGraph, + sv_ids: np.ndarray, + root_id: basetypes.NODE_ID, + bbox: np.ndarray, + new_seg: np.ndarray, + old_new_map: dict, +): + old_new_map = dict(old_new_map) + kdtrees, _ = build_kdtrees_by_label(new_seg) + distance_map = dict(zip(kdtrees.keys(), np.arange(len(kdtrees)))) + dist_vec = np.vectorize(distance_map.get) + + _, edges_tuple = cg.get_subgraph(root_id, bbox, bbox_is_coordinate=True) + edges_ = reduce(lambda x, y: x + y, edges_tuple, Edges([], [])) + + edges = edges_.get_pairs() + affinities = edges_.affinities + areas = edges_.areas + + edges = np.sort(edges, axis=1) + _, edges_idx = np.unique(edges, axis=0, return_index=True) + edges_idx = edges_idx[edges[edges_idx, 0] != edges[edges_idx, 1]] + + edges = edges[edges_idx] + affinities = affinities[edges_idx] + areas = areas[edges_idx] + logging.info(f"edges.shape, affinities.shape {edges.shape, affinities.shape}") + + new_ids = np.array(list(set.union(*old_new_map.values())), dtype=basetypes.NODE_ID) + new_kdtrees = [kdtrees[k] for k in new_ids] + new_disance_map = dict(zip(new_ids, np.arange(len(new_ids)))) + new_dist_vec = np.vectorize(new_disance_map.get) + distances = pairwise_min_distance_two_sets(new_kdtrees, list(kdtrees.values())) + return _get_new_edges( + (edges, affinities, areas), + sv_ids, + old_new_map, + distances, + dist_vec, + new_dist_vec, + ) + + +def _add_new_edges(cg: ChunkedGraph, edges_tuple: tuple, time_stamp: datetime = None): + edges_, affinites_, areas_ = edges_tuple + logging.info(f"new edges: {edges_.shape}") + + nodes = fastremap.unique(edges_) + chunks = cg.get_chunk_ids_from_node_ids(cg.get_parents(nodes)) + node_chunks = dict(zip(nodes, chunks)) + + edges = np.r_[edges_, edges_[:, ::-1]] + affinites = np.r_[affinites_, affinites_] + areas = np.r_[areas_, areas_] + + rows = [] + chunks_arr = fastremap.remap(edges, node_chunks) + for chunk_id in np.unique(chunks): + val_dict = {} + mask = 
chunks_arr[:, 0] == chunk_id + val_dict[Connectivity.SplitEdges] = edges[mask] + val_dict[Connectivity.Affinity] = affinites[mask] + val_dict[Connectivity.Area] = areas[mask] + rows.append( + cg.client.mutate_row( + serialize_uint64(chunk_id, fake_edges=True), + val_dict=val_dict, + time_stamp=time_stamp, + ) + ) + logging.info(f"writing {edges[mask].shape} edges to {chunk_id}") + return rows + + +def split_supervoxel( + cg: ChunkedGraph, + sv_id: basetypes.NODE_ID, + source_coords: np.ndarray, + sink_coords: np.ndarray, + operation_id: int, + verbose: bool = True, + time_stamp: datetime = None, +) -> dict[int, set]: + """ + Lookups coordinates of given supervoxel in segmentation. + Finds its counterparts split by chunk boundaries and splits them as a whole. + Updates the segmentation with new IDs. + """ + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + chunk_size = cg.meta.graph_config.CHUNK_SIZE + _coords = np.concatenate([source_coords, sink_coords]) + _padding = np.array([64] * 3) / cg.meta.resolution + + bbs = np.clip((np.min(_coords, 0) - _padding).astype(int), vol_start, vol_end) + bbe = np.clip((np.max(_coords, 0) + _padding).astype(int), vol_start, vol_end) + chunk_min, chunk_max = bbs // chunk_size, np.ceil(bbe / chunk_size).astype(int) + bbs, bbe = chunk_min * chunk_size, chunk_max * chunk_size + logging.info(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}") + logging.info(f"{chunk_size}; {_padding}; {(bbs, bbe)}; {(chunk_min, chunk_max)}") + + cut_supervoxels = _get_whole_sv(cg, sv_id, min_coord=chunk_min, max_coord=chunk_max) + supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) + logging.info(f"{sv_id} -> {cut_supervoxels}") + + # one voxel overlap for neighbors + bbs_ = np.clip(bbs - 1, vol_start, vol_end) + bbe_ = np.clip(bbe + 1, vol_start, vol_end) + seg = get_local_segmentation(cg.meta, bbs_, bbe_).squeeze() + binary_seg = np.isin(seg, supervoxel_ids) + logging.info(f"{seg.shape}; {binary_seg.shape}; {bbs, bbe}; {bbs_, bbe_}") + + voxel_overlap_crop = _voxel_crop(bbs, bbe, bbs_, bbe_) + split_result = split_supervoxel_helper( + binary_seg[voxel_overlap_crop], + source_coords - bbs, + sink_coords - bbs, + cg.meta.resolution, + verbose=verbose, + ) + logging.info(f"split_result: {split_result.shape}") + + chunks_bbox_map = chunks_overlapping_bbox(bbs, bbe, cg.meta.graph_config.CHUNK_SIZE) + tasks = [ + (cg.graph_id, *item, seg[voxel_overlap_crop], split_result, bbs) + for item in chunks_bbox_map.items() + ] + logging.info(f"tasks count: {len(tasks)}") + with mp.Pool() as pool: + results = [*tqdm(pool.imap_unordered(_update_chunk, tasks), total=len(tasks))] + seg_cropped = seg[voxel_overlap_crop].copy() + new_seg, old_new_map, slices = _parse_results(results, seg_cropped, bbs, bbe) + + seg_roots = seg.copy() + sv_ids = fastremap.unique(seg) + roots = cg.get_roots(sv_ids) + seg_roots = fastremap.remap(seg_roots, dict(zip(sv_ids, roots)), in_place=True) + + root = cg.get_root(sv_id) + logging.info(f"root {root}") + + seg_masked = seg.copy() + seg_masked[seg_roots != root] = 0 + sv_ids = fastremap.unique(seg_masked) + + seg_masked[voxel_overlap_crop] = new_seg + edges_tuple = _update_edges( + cg, sv_ids, root, np.array([bbs, bbe]), seg_masked, old_new_map + ) + + rows0 = copy_parents_and_add_lineage(cg, operation_id, old_new_map) + rows1 = _add_new_edges(cg, edges_tuple, time_stamp=time_stamp) + rows = rows0 + rows1 + logging.info(f"{operation_id}: writing {len(rows)} new rows") + + cg.client.write(rows) + 
cg.meta.ws_ocdbt[slices] = new_seg[..., np.newaxis] + return old_new_map, edges_tuple + + +def copy_parents_and_add_lineage( + cg: ChunkedGraph, + operation_id: int, + old_new_map: dict, +) -> list: + """ + Copy parents column from `old_id` to each of `new_ids`. + This makes it easy to get old hierarchy with `new_ids` using an older timestamp. + Link `old_id` and `new_ids` to create a lineage at supervoxel layer. + Returns a list of mutations to be persisted. + """ + result = [] + parents = set() + old_new_map = {k: list(v) for k, v in old_new_map.items()} + parent_cells_map = cg.client.read_nodes( + node_ids=list(old_new_map.keys()), properties=Hierarchy.Parent + ) + for old_id, new_ids in old_new_map.items(): + for new_id in new_ids: + val_dict = { + Hierarchy.FormerIdentity: np.array([old_id], dtype=basetypes.NODE_ID), + OperationLogs.OperationID: operation_id, + } + result.append(cg.client.mutate_row(serialize_uint64(new_id), val_dict)) + for cell in parent_cells_map[old_id]: + cache_utils.update(cg.cache.parents_cache, [new_id], cell.value) + parents.add(cell.value) + result.append( + cg.client.mutate_row( + serialize_uint64(new_id), + {Hierarchy.Parent: cell.value}, + time_stamp=cell.timestamp, + ) + ) + val_dict = {Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID)} + result.append(cg.client.mutate_row(serialize_uint64(old_id), val_dict)) + + children_cells_map = cg.client.read_nodes( + node_ids=list(parents), properties=Hierarchy.Child + ) + for parent, children_cells in children_cells_map.items(): + assert len(children_cells) == 1, children_cells + for cell in children_cells: + logging.info(f"{parent}: {cell.value}") + mask = np.isin(cell.value, list(old_new_map.keys())) + replace = np.concatenate([old_new_map[x] for x in cell.value[mask]]) + children = np.concatenate([cell.value[~mask], replace]) + logging.info(f"{parent}: {children}") + cg.cache.children_cache[parent] = children + result.append( + cg.client.mutate_row( + serialize_uint64(parent), + {Hierarchy.Child: children}, + time_stamp=cell.timestamp, + ) + ) + return result diff --git a/pychunkedgraph/graph/types.py b/pychunkedgraph/graph/types.py index 1f35e5f6b..ccbd8b8c8 100644 --- a/pychunkedgraph/graph/types.py +++ b/pychunkedgraph/graph/types.py @@ -7,7 +7,8 @@ empty_1d = np.empty(0, dtype=basetypes.NODE_ID) empty_2d = np.empty((0, 2), dtype=basetypes.NODE_ID) - +empty_affinities = np.empty(0, dtype=basetypes.EDGE_AFFINITY) +empty_areas = np.empty(0, dtype=basetypes.EDGE_AREA) """ An Agglomeration is syntactic sugar for representing diff --git a/pychunkedgraph/graph/utils/__init__.py b/pychunkedgraph/graph/utils/__init__.py index e69de29bb..c1d56e0fe 100644 --- a/pychunkedgraph/graph/utils/__init__.py +++ b/pychunkedgraph/graph/utils/__init__.py @@ -0,0 +1 @@ +from .generic import get_local_segmentation \ No newline at end of file diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py index 9a2b6f979..84d5b72bf 100644 --- a/pychunkedgraph/graph/utils/generic.py +++ b/pychunkedgraph/graph/utils/generic.py @@ -3,7 +3,6 @@ TODO categorize properly """ - import datetime from typing import Dict from typing import Iterable @@ -173,9 +172,7 @@ def mask_nodes_by_bounding_box( adapt_layers = layers - 2 adapt_layers[adapt_layers < 0] = 0 fanout = meta.graph_config.FANOUT - bounding_box_layer = ( - bounding_box[None] / (fanout ** adapt_layers)[:, None, None] - ) + bounding_box_layer = bounding_box[None] / (fanout**adapt_layers)[:, None, None] bound_check = np.array( [ 
np.all(chunk_coordinates < bounding_box_layer[:, 1], axis=1), @@ -183,4 +180,15 @@ def mask_nodes_by_bounding_box( ] ).T - return np.all(bound_check, axis=1) \ No newline at end of file + return np.all(bound_check, axis=1) + + +def get_local_segmentation(meta, bbox_start, bbox_end) -> np.ndarray: + result = None + xL, yL, zL = bbox_start + xH, yH, zH = bbox_end + if meta.ocdbt_seg: + result = meta.ws_ocdbt[xL:xH, yL:yH, zL:zH].read().result() + else: + result = meta.cv[xL:xH, yL:yH, zL:zH] + return result diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py index aa486ac84..bcdea92f0 100644 --- a/pychunkedgraph/graph/utils/id_helpers.py +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -10,6 +10,7 @@ import numpy as np from . import basetypes +from .generic import get_local_segmentation from ..meta import ChunkedGraphMeta from ..chunks import utils as chunk_utils @@ -140,10 +141,7 @@ def get_atomic_ids_from_coords( ] ) - local_sv_seg = meta.cv[ - bbox[0, 0] : bbox[1, 0], bbox[0, 1] : bbox[1, 1], bbox[0, 2] : bbox[1, 2] - ].squeeze() - + local_sv_seg = get_local_segmentation(meta, bbox[0], bbox[1]).squeeze() # limit get_roots calls to the relevant areas of the data lower_bs = np.floor( (np.array(coordinates_nm) - max_dist_nm) / np.array(meta.resolution) - bbox[0] diff --git a/pychunkedgraph/meshing/meshgen_utils.py b/pychunkedgraph/meshing/meshgen_utils.py index 711c09322..66c5a50c4 100644 --- a/pychunkedgraph/meshing/meshgen_utils.py +++ b/pychunkedgraph/meshing/meshgen_utils.py @@ -1,19 +1,11 @@ import re -import multiprocessing as mp -from time import time -from typing import List -from typing import Dict -from typing import Tuple from typing import Sequence from functools import lru_cache import numpy as np -from cloudvolume import CloudVolume from cloudvolume.lib import Vec -from multiwrapper import multiprocessing_utils as mu -from pychunkedgraph.graph.utils.basetypes import NODE_ID # noqa -from ..graph.types import empty_1d +from pychunkedgraph.graph.utils import get_local_segmentation def str_to_slice(slice_str: str): @@ -151,11 +143,9 @@ def get_json_info(cg): def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): - cv = CloudVolume(cg.meta.cv.cloudpath, mip=mip, fill_missing=True) mip_diff = mip - cg.meta.cv.mip - mip_chunk_size = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) / np.array( - [2 ** mip_diff, 2 ** mip_diff, 1] + [2**mip_diff, 2**mip_diff, 1] ) mip_chunk_size = mip_chunk_size.astype(int) @@ -169,11 +159,5 @@ def get_ws_seg_for_chunk(cg, chunk_id, mip, overlap_vx=1): cg.meta.cv.mip_voxel_offset(mip), cg.meta.cv.mip_voxel_offset(mip) + cg.meta.cv.mip_volume_size(mip), ) - - ws_seg = cv[ - chunk_start[0] : chunk_end[0], - chunk_start[1] : chunk_end[1], - chunk_start[2] : chunk_end[2], - ].squeeze() - + ws_seg = get_local_segmentation(cg.meta, chunk_start, chunk_end).squeeze() return ws_seg diff --git a/requirements.in b/requirements.in index 4fcd353ed..fe4a6352d 100644 --- a/requirements.in +++ b/requirements.in @@ -16,6 +16,9 @@ pyyaml cachetools werkzeug tensorstore +edt +connected-components-3d +scikit-image # PyPI only: cloud-files>=4.21.1 diff --git a/requirements.txt b/requirements.txt index 0eedacb31..5cf38561d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -56,6 +56,8 @@ compressed-segmentation==2.2.1 # via cloud-volume compresso==3.2.1 # via cloud-volume +connected-components-3d==3.24.0 + # via -r requirements.in crackle-codec==0.7.0 # via cloud-volume crc32c==2.3.post0 @@ -72,6 +74,8 @@ 
dracopy==1.3.0 # via # -r requirements.in # cloud-volume +edt==3.0.0 + # via -r requirements.in fasteners==0.19 # via cloud-files fastremap==1.14.0 @@ -161,6 +165,8 @@ grpcio-status==1.58.0 # google-cloud-pubsub idna==3.4 # via requests +imageio==2.37.0 + # via scikit-image inflection==0.5.1 # via python-jsonschema-objects iniconfig==2.0.0 @@ -181,6 +187,8 @@ jsonschema==4.19.1 # python-jsonschema-objects jsonschema-specifications==2023.7.1 # via jsonschema +lazy-loader==0.4 + # via scikit-image markdown==3.4.4 # via python-jsonschema-objects markupsafe==2.1.3 @@ -201,23 +209,30 @@ networkx==3.1 # via # -r requirements.in # cloud-volume + # scikit-image numpy==1.26.0 # via # -r requirements.in # cloud-volume # compressed-segmentation # compresso + # connected-components-3d # crackle-codec + # edt # fastremap # fpzip + # imageio # messagingclient # ml-dtypes # multiwrapper # pandas # pyspng-seunglab + # scikit-image + # scipy # simplejpeg # task-queue # tensorstore + # tifffile # zfpc # zmesh orderedmultidict==1.0.1 @@ -227,7 +242,10 @@ orjson==3.9.7 # cloud-files # task-queue packaging==23.1 - # via pytest + # via + # lazy-loader + # pytest + # scikit-image pandas==2.1.1 # via -r requirements.in pathos==0.3.1 @@ -238,7 +256,10 @@ pathos==0.3.1 pbr==5.11.1 # via task-queue pillow==10.0.1 - # via cloud-volume + # via + # cloud-volume + # imageio + # scikit-image pluggy==1.3.0 # via pytest posix-ipc==1.1.1 @@ -323,6 +344,10 @@ rsa==4.9 # google-auth s3transfer==0.6.2 # via boto3 +scikit-image==0.24.0 + # via -r requirements.in +scipy==1.16.1 + # via scikit-image simplejpeg==1.7.2 # via cloud-volume six==1.16.0 @@ -342,6 +367,8 @@ tenacity==8.2.3 # task-queue tensorstore==0.1.53 # via -r requirements.in +tifffile==2025.6.11 + # via scikit-image tqdm==4.66.1 # via # cloud-files From f6ada50d872e7063bd7188b9411e0155cde4848f Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Wed, 10 Sep 2025 21:40:37 +0000 Subject: [PATCH 4/5] feat(sv_split): sv split in frontend --- pychunkedgraph/app/segmentation/common.py | 94 ++++++++++++++++------- pychunkedgraph/graph/cutting.py | 21 ++--- pychunkedgraph/graph/exceptions.py | 30 +++++++- pychunkedgraph/graph/operation.py | 11 ++- pychunkedgraph/repair/fake_edges.py | 6 +- 5 files changed, 118 insertions(+), 44 deletions(-) diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 70642c9ce..6b73e0050 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -9,16 +9,13 @@ import numpy as np import pandas as pd +import fastremap from flask import current_app, g, jsonify, make_response, request from pytz import UTC from pychunkedgraph import __version__ from pychunkedgraph.app import app_utils -from pychunkedgraph.graph import ( - attributes, - cutting, - segmenthistory, -) +from pychunkedgraph.graph import attributes, cutting, segmenthistory, ChunkedGraph from pychunkedgraph.graph import ( edges as cg_edges, ) @@ -27,6 +24,7 @@ ) from pychunkedgraph.graph.analysis import pathing from pychunkedgraph.graph.attributes import OperationLogs +from pychunkedgraph.graph.edits_sv import split_supervoxel from pychunkedgraph.graph.misc import get_contact_sites from pychunkedgraph.graph.operation import GraphEditOperation from pychunkedgraph.graph.utils import basetypes @@ -393,7 +391,7 @@ def handle_merge(table_id, allow_same_segment_merge=False): current_app.operation_id = ret.operation_id if ret.new_root_ids is None: raise cg_exceptions.InternalServerError( - "Could not 
merge selected " "supervoxel." + f"{ret.operation_id}: Could not merge selected supervoxels." ) current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids)) @@ -407,24 +405,10 @@ def handle_merge(table_id, allow_same_segment_merge=False): ### SPLIT ---------------------------------------------------------------------- -def handle_split(table_id): - current_app.table_id = table_id - user_id = str(g.auth_user.get("id", current_app.user_id)) - - data = json.loads(request.data) - is_priority = request.args.get("priority", True, type=str2bool) - remesh = request.args.get("remesh", True, type=str2bool) - mincut = request.args.get("mincut", True, type=str2bool) - +def _get_sources_and_sinks(cg: ChunkedGraph, data): current_app.logger.debug(data) - - # Call ChunkedGraph - cg = app_utils.get_cg(table_id, skip_cache=True) node_idents = [] - node_ident_map = { - "sources": 0, - "sinks": 1, - } + node_ident_map = {"sources": 0, "sinks": 1} coords = [] node_ids = [] @@ -437,18 +421,74 @@ def handle_split(table_id): node_ids = np.array(node_ids, dtype=np.uint64) coords = np.array(coords) node_idents = np.array(node_idents) + + start = time.time() sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids) + current_app.logger.info(f"SV lookup took {time.time() - start}s.") current_app.logger.debug( {"node_id": node_ids, "sv_id": sv_ids, "node_ident": node_idents} ) + source_ids = sv_ids[node_idents == 0] + sink_ids = sv_ids[node_idents == 1] + source_coords = coords[node_idents == 0] + sink_coords = coords[node_idents == 1] + return (source_ids, sink_ids, source_coords, sink_coords) + + +def handle_split(table_id): + current_app.table_id = table_id + user_id = str(g.auth_user.get("id", current_app.user_id)) + + data = json.loads(request.data) + is_priority = request.args.get("priority", True, type=str2bool) + remesh = request.args.get("remesh", True, type=str2bool) + mincut = request.args.get("mincut", True, type=str2bool) + + cg = app_utils.get_cg(table_id, skip_cache=True) + sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) try: ret = cg.remove_edges( user_id=user_id, - source_ids=sv_ids[node_idents == 0], - sink_ids=sv_ids[node_idents == 1], - source_coords=coords[node_idents == 0], - sink_coords=coords[node_idents == 1], + source_ids=sources, + sink_ids=sinks, + source_coords=source_coords, + sink_coords=sink_coords, + mincut=mincut, + ) + except cg_exceptions.SupervoxelSplitRequiredError as e: + current_app.logger.info(e) + sources_remapped = fastremap.remap( + sources, + e.sv_remapping, + preserve_missing_labels=True, + in_place=False, + ) + sinks_remapped = fastremap.remap( + sinks, + e.sv_remapping, + preserve_missing_labels=True, + in_place=False, + ) + overlap_mask = np.isin(sources_remapped, sinks_remapped) + for sv_to_split in np.unique(sources_remapped[overlap_mask]): + _mask0 = sources_remapped[sources_remapped == sv_to_split] + _mask1 = sinks_remapped[sinks_remapped == sv_to_split] + split_supervoxel( + cg, + sv_to_split, + source_coords[_mask0], + sink_coords[_mask1], + e.operation_id, + ) + + sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) + ret = cg.remove_edges( + user_id=user_id, + source_ids=sources, + sink_ids=sinks, + source_coords=source_coords, + sink_coords=sink_coords, mincut=mincut, ) except cg_exceptions.LockingError as e: @@ -459,7 +499,7 @@ def handle_split(table_id): current_app.operation_id = ret.operation_id if ret.new_root_ids is None: raise cg_exceptions.InternalServerError( - "Could not split selected 
segment groups." + f"{ret.operation_id}: Could not split selected segment groups." ) current_app.logger.debug(("after split:", ret.new_root_ids)) diff --git a/pychunkedgraph/graph/cutting.py b/pychunkedgraph/graph/cutting.py index 8b1583871..2c86c1091 100644 --- a/pychunkedgraph/graph/cutting.py +++ b/pychunkedgraph/graph/cutting.py @@ -1,23 +1,18 @@ -import collections import fastremap import numpy as np import itertools -import logging import time import graph_tool import graph_tool.flow -from typing import Dict from typing import Tuple -from typing import Optional from typing import Sequence from typing import Iterable from .utils import flatgraph -from .utils import basetypes from .utils.generic import get_bounding_box from .edges import Edges -from .exceptions import PreconditionError +from .exceptions import PreconditionError, SupervoxelSplitRequiredError from .exceptions import PostconditionError DEBUG_MODE = False @@ -116,6 +111,10 @@ def __init__( self.cross_chunk_edge_remapping, ) = merge_cross_chunk_edges_graph_tool(cg_edges, cg_affs) + # save this representative mapping for supervoxel splitting + # passed along with SupervoxelSplitRequiredError + self.sv_remapping = dict(complete_mapping) + dt = time.time() - time_start if logger is not None: logger.debug("Cross edge merging: %.2fms" % (dt * 1000)) @@ -233,9 +232,10 @@ def _augment_mincut_capacity(self): self.source_graph_ids, ) except AssertionError: - raise PreconditionError( + raise SupervoxelSplitRequiredError( "Paths between source or sink points irreparably overlap other labels from other side. " - "Check that labels are correct and consider spreading points out farther." + "Check that labels are correct and consider spreading points out farther.", + self.sv_remapping ) paths_e_s_no, paths_e_y_no, do_check = flatgraph.remove_overlapping_edges( @@ -581,11 +581,12 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set): # but return a flag to return a message to the user illegal_split = True else: - raise PreconditionError( + raise SupervoxelSplitRequiredError( "Failed to find a cut that separated the sources from the sinks. " "Please try another cut that partitions the sets cleanly if possible. " "If there is a clear path between all the supervoxels in each set, " - "that helps the mincut algorithm." 
+ "that helps the mincut algorithm.", + self.sv_remapping ) except IsolatingCutException as e: if self.split_preview: diff --git a/pychunkedgraph/graph/exceptions.py b/pychunkedgraph/graph/exceptions.py index 45aa57fc7..78d4154a7 100644 --- a/pychunkedgraph/graph/exceptions.py +++ b/pychunkedgraph/graph/exceptions.py @@ -3,21 +3,25 @@ class ChunkedGraphError(Exception): """Base class for all exceptions raised by the ChunkedGraph""" + pass class LockingError(ChunkedGraphError): """Raised when a Bigtable Lock could not be acquired""" + pass class PreconditionError(ChunkedGraphError): """Raised when preconditions for Chunked Graph operations are not met""" + pass class PostconditionError(ChunkedGraphError): """Raised when postconditions for Chunked Graph operations are not met""" + pass @@ -42,7 +46,7 @@ def __init__(self, message): self.message = message def __str__(self): - return f'[{self.status_code}]: {self.message}' + return f"[{self.status_code}]: {self.message}" class ClientError(ChunkedGraphAPIError): @@ -51,21 +55,25 @@ class ClientError(ChunkedGraphAPIError): class BadRequest(ClientError): """Exception mapping a ``400 Bad Request`` response.""" + status_code = http_client.BAD_REQUEST class Unauthorized(ClientError): """Exception mapping a ``401 Unauthorized`` response.""" + status_code = http_client.UNAUTHORIZED class Forbidden(ClientError): """Exception mapping a ``403 Forbidden`` response.""" + status_code = http_client.FORBIDDEN class Conflict(ClientError): """Exception mapping a ``409 Conflict`` response.""" + status_code = http_client.CONFLICT @@ -75,9 +83,29 @@ class ServerError(ChunkedGraphAPIError): class InternalServerError(ServerError): """Exception mapping a ``500 Internal Server Error`` response.""" + status_code = http_client.INTERNAL_SERVER_ERROR class GatewayTimeout(ServerError): """Exception mapping a ``504 Gateway Timeout`` response.""" + status_code = http_client.GATEWAY_TIMEOUT + + +class SupervoxelSplitRequiredError(ChunkedGraphError): + """ + Raised when supervoxel splitting is necessary. + Edit process should catch this error and retry after supervoxel has been split. + Saves remapping required for detecting which supervoxels need to be split. 
+ """ + + def __init__( + self, + message: str, + sv_remapping: dict, + operation_id: int | None = None, + ): + super().__init__(message) + self.sv_remapping = sv_remapping + self.operation_id = operation_id diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 8c5d4484e..6b3778251 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -26,7 +26,7 @@ from .utils import serializers from .cache import CacheService from .cutting import run_multicut -from .exceptions import PreconditionError +from .exceptions import PreconditionError, SupervoxelSplitRequiredError from .exceptions import PostconditionError from .utils.generic import get_bounding_box as get_bbox from ..logging.log_db import TimeIt @@ -451,6 +451,10 @@ def execute( new_root_ids=new_root_ids, new_lvl2_ids=new_lvl2_ids, ) + except SupervoxelSplitRequiredError as err: + raise SupervoxelSplitRequiredError( + str(err), err.sv_remapping, operation_id=lock.operation_id + ) from err except PreconditionError as err: self.cg.cache = None raise PreconditionError(err) from err @@ -852,9 +856,10 @@ def __init__( self.path_augment = path_augment self.disallow_isolating_cut = disallow_isolating_cut if np.any(np.in1d(self.sink_ids, self.source_ids)): - raise PreconditionError( + raise SupervoxelSplitRequiredError( "Supervoxels exist in both sink and source, " - "try placing the points further apart." + "try placing the points further apart.", + None, ) ids = np.concatenate([self.source_ids, self.sink_ids]) diff --git a/pychunkedgraph/repair/fake_edges.py b/pychunkedgraph/repair/fake_edges.py index b58b93fb9..1c0e26fd2 100644 --- a/pychunkedgraph/repair/fake_edges.py +++ b/pychunkedgraph/repair/fake_edges.py @@ -9,9 +9,9 @@ from os import environ from typing import Optional -environ["BIGTABLE_PROJECT"] = "<>" -environ["BIGTABLE_INSTANCE"] = "<>" -environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" +# environ["BIGTABLE_PROJECT"] = "<>" +# environ["BIGTABLE_INSTANCE"] = "<>" +# environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" from pychunkedgraph.graph import edits from pychunkedgraph.graph import ChunkedGraph From e6106b093c258a4cf18b1c584c1f0c570a2bb7b3 Mon Sep 17 00:00:00 2001 From: Akhilesh Halageri Date: Thu, 9 Oct 2025 20:26:55 +0000 Subject: [PATCH 5/5] fix(sv_split): update multicut test --- pychunkedgraph/app/segmentation/common.py | 4 ++-- pychunkedgraph/graph/edits_sv.py | 9 +++++---- pychunkedgraph/tests/test_uncategorized.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 6b73e0050..5566e81d7 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -472,8 +472,8 @@ def handle_split(table_id): ) overlap_mask = np.isin(sources_remapped, sinks_remapped) for sv_to_split in np.unique(sources_remapped[overlap_mask]): - _mask0 = sources_remapped[sources_remapped == sv_to_split] - _mask1 = sinks_remapped[sinks_remapped == sv_to_split] + _mask0 = sources_remapped == sv_to_split + _mask1 = sinks_remapped == sv_to_split split_supervoxel( cg, sv_to_split, diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index bb50505b0..4ac3a40f7 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -56,10 +56,11 @@ def _get_whole_sv( mask = _cx_edges[:, 0] == vertex neighbors = _cx_edges[mask][:, 1] - neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) - min_mask = 
(neighbor_coords >= min_coord).all(axis=1) - max_mask = (neighbor_coords < max_coord).all(axis=1) - neighbors = neighbors[min_mask & max_mask] + if len(neighbors) > 0: + neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) + min_mask = (neighbor_coords >= min_coord).all(axis=1) + max_mask = (neighbor_coords < max_coord).all(axis=1) + neighbors = neighbors[min_mask & max_mask] for neighbor in neighbors: if neighbor in explored_nodes: diff --git a/pychunkedgraph/tests/test_uncategorized.py b/pychunkedgraph/tests/test_uncategorized.py index 766b81bca..50002cdb3 100644 --- a/pychunkedgraph/tests/test_uncategorized.py +++ b/pychunkedgraph/tests/test_uncategorized.py @@ -1872,7 +1872,7 @@ def test_path_augmented_multicut(self, sv_data): cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) assert cut_edges_aug.shape[0] == 350 - with pytest.raises(exceptions.PreconditionError): + with pytest.raises(exceptions.SupervoxelSplitRequiredError): run_multicut(edges, sv_sources, sv_sinks, path_augment=False)