94 changes: 67 additions & 27 deletions pychunkedgraph/app/segmentation/common.py
@@ -9,16 +9,13 @@

import numpy as np
import pandas as pd
import fastremap
from flask import current_app, g, jsonify, make_response, request
from pytz import UTC

from pychunkedgraph import __version__
from pychunkedgraph.app import app_utils
from pychunkedgraph.graph import (
attributes,
cutting,
segmenthistory,
)
from pychunkedgraph.graph import attributes, cutting, segmenthistory, ChunkedGraph
from pychunkedgraph.graph import (
edges as cg_edges,
)
@@ -27,6 +24,7 @@
)
from pychunkedgraph.graph.analysis import pathing
from pychunkedgraph.graph.attributes import OperationLogs
from pychunkedgraph.graph.edits_sv import split_supervoxel
from pychunkedgraph.graph.misc import get_contact_sites
from pychunkedgraph.graph.operation import GraphEditOperation
from pychunkedgraph.graph.utils import basetypes
@@ -393,7 +391,7 @@ def handle_merge(table_id, allow_same_segment_merge=False):
current_app.operation_id = ret.operation_id
if ret.new_root_ids is None:
raise cg_exceptions.InternalServerError(
"Could not merge selected " "supervoxel."
f"{ret.operation_id}: Could not merge selected supervoxels."
)

current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))
@@ -407,24 +405,10 @@
### SPLIT ----------------------------------------------------------------------


def handle_split(table_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

data = json.loads(request.data)
is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
mincut = request.args.get("mincut", True, type=str2bool)

def _get_sources_and_sinks(cg: ChunkedGraph, data):
current_app.logger.debug(data)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id, skip_cache=True)
node_idents = []
node_ident_map = {
"sources": 0,
"sinks": 1,
}
node_ident_map = {"sources": 0, "sinks": 1}
coords = []
node_ids = []

@@ -437,18 +421,74 @@ def handle_split(table_id)
node_ids = np.array(node_ids, dtype=np.uint64)
coords = np.array(coords)
node_idents = np.array(node_idents)

start = time.time()
sv_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids)
current_app.logger.info(f"SV lookup took {time.time() - start}s.")
current_app.logger.debug(
{"node_id": node_ids, "sv_id": sv_ids, "node_ident": node_idents}
)

source_ids = sv_ids[node_idents == 0]
sink_ids = sv_ids[node_idents == 1]
source_coords = coords[node_idents == 0]
sink_coords = coords[node_idents == 1]
return (source_ids, sink_ids, source_coords, sink_coords)


def handle_split(table_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

data = json.loads(request.data)
is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
mincut = request.args.get("mincut", True, type=str2bool)

cg = app_utils.get_cg(table_id, skip_cache=True)
sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data)
try:
ret = cg.remove_edges(
user_id=user_id,
source_ids=sv_ids[node_idents == 0],
sink_ids=sv_ids[node_idents == 1],
source_coords=coords[node_idents == 0],
sink_coords=coords[node_idents == 1],
source_ids=sources,
sink_ids=sinks,
source_coords=source_coords,
sink_coords=sink_coords,
mincut=mincut,
)
except cg_exceptions.SupervoxelSplitRequiredError as e:
current_app.logger.info(e)
sources_remapped = fastremap.remap(
sources,
e.sv_remapping,
preserve_missing_labels=True,
in_place=False,
)
sinks_remapped = fastremap.remap(
sinks,
e.sv_remapping,
preserve_missing_labels=True,
in_place=False,
)
overlap_mask = np.isin(sources_remapped, sinks_remapped)
for sv_to_split in np.unique(sources_remapped[overlap_mask]):
_mask0 = sources_remapped == sv_to_split
_mask1 = sinks_remapped == sv_to_split
split_supervoxel(
cg,
sv_to_split,
source_coords[_mask0],
sink_coords[_mask1],
e.operation_id,
)

sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data)
ret = cg.remove_edges(
user_id=user_id,
source_ids=sources,
sink_ids=sinks,
source_coords=source_coords,
sink_coords=sink_coords,
mincut=mincut,
)
except cg_exceptions.LockingError as e:
@@ -459,7 +499,7 @@ def handle_split(table_id)
current_app.operation_id = ret.operation_id
if ret.new_root_ids is None:
raise cg_exceptions.InternalServerError(
"Could not split selected segment groups."
f"{ret.operation_id}: Could not split selected segment groups."
)

current_app.logger.debug(("after split:", ret.new_root_ids))
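Note on the retry path in handle_split above: when remove_edges raises SupervoxelSplitRequiredError, the handler remaps the requested sources and sinks and any supervoxel that then appears on both sides must first be split before the edit is retried. A minimal sketch of that overlap check, with made-up IDs and a made-up sv_remapping dict standing in for the one carried by the exception:

```python
import numpy as np
import fastremap

# Hypothetical inputs: supervoxel IDs and the remapping carried by the
# SupervoxelSplitRequiredError (all values made up for illustration).
sources = np.array([101, 102], dtype=np.uint64)
sinks = np.array([103, 104], dtype=np.uint64)
sv_remapping = {103: 101}

# preserve_missing_labels=True keeps IDs that are absent from the remapping.
sources_remapped = fastremap.remap(
    sources, sv_remapping, preserve_missing_labels=True, in_place=False
)
sinks_remapped = fastremap.remap(
    sinks, sv_remapping, preserve_missing_labels=True, in_place=False
)

# A supervoxel appearing on both sides after remapping must itself be split
# (via split_supervoxel) before remove_edges can be retried.
overlap_mask = np.isin(sources_remapped, sinks_remapped)
print(np.unique(sources_remapped[overlap_mask]))  # -> [101]
```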
19 changes: 19 additions & 0 deletions pychunkedgraph/graph/attributes.py
@@ -132,6 +132,13 @@ class Connectivity:
serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)),
)

# new edges as a result of supervoxel split
SplitEdges = _Attribute(
key=b"split_edges",
family_id="4",
serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID, shape=(-1, 2)),
)


class Hierarchy:
Child = _Attribute(
@@ -160,6 +167,18 @@ class Hierarchy:
serializer=serializers.NumPyValue(dtype=basetypes.NODE_ID),
)

FormerIdentity = _Attribute(
key=b"former_ids",
family_id="0",
serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID),
)

NewIdentity = _Attribute(
key=b"new_ids",
family_id="0",
serializer=serializers.NumPyArray(dtype=basetypes.NODE_ID),
)


class GraphMeta:
key = b"meta"
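A hedged reading of the two new Hierarchy attributes (an assumption, not spelled out in the diff): they appear to link a split supervoxel to its replacements in both directions, roughly like this:

```python
import numpy as np

# Assumed bookkeeping, for illustration only: supervoxel 123 is split into
# 124 and 125. The retired node records its replacements under NewIdentity,
# and each replacement records its origin under FormerIdentity.
new_ids_of_123 = np.array([124, 125], dtype=np.uint64)  # Hierarchy.NewIdentity on 123
former_ids_of_124 = np.array([123], dtype=np.uint64)    # Hierarchy.FormerIdentity on 124
former_ids_of_125 = np.array([123], dtype=np.uint64)    # Hierarchy.FormerIdentity on 125
```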
42 changes: 32 additions & 10 deletions pychunkedgraph/graph/chunkedgraph.py
@@ -626,22 +626,44 @@ def get_subgraph_leaves(
self, node_id_or_ids, bbox, bbox_is_coordinate, False, True
)

def get_fake_edges(
def get_edited_edges(
self, chunk_ids: np.ndarray, time_stamp: datetime.datetime = None
) -> typing.Dict:
"""
Edges stored within a pcg that were created as a result of edits.
Either 'fake' edges added for a merge edit,
or 'split' edges resulting from a supervoxel split.
"""
result = {}
fake_edges_d = self.client.read_nodes(
properties = [
attributes.Connectivity.FakeEdges,
attributes.Connectivity.SplitEdges,
attributes.Connectivity.Affinity,
attributes.Connectivity.Area,
]
_edges_d = self.client.read_nodes(
node_ids=chunk_ids,
properties=attributes.Connectivity.FakeEdges,
properties=properties,
end_time=time_stamp,
end_time_inclusive=True,
fake_edges=True,
)
for id_, val in fake_edges_d.items():
edges = np.concatenate(
[np.array(e.value, dtype=basetypes.NODE_ID, copy=False) for e in val]
)
result[id_] = Edges(edges[:, 0], edges[:, 1])
for id_, val in _edges_d.items():
edges = val.get(attributes.Connectivity.FakeEdges, [])
edges = np.concatenate([types.empty_2d, *[e.value for e in edges]])
fake_edges_ = Edges(edges[:, 0], edges[:, 1])

edges = val.get(attributes.Connectivity.SplitEdges, [])
edges = np.concatenate([types.empty_2d, *[e.value for e in edges]])

aff = val.get(attributes.Connectivity.Affinity, [])
aff = np.concatenate([types.empty_affinities, *[e.value for e in aff]])

areas = val.get(attributes.Connectivity.Area, [])
areas = np.concatenate([types.empty_areas, *[e.value for e in areas]])
split_edges_ = Edges(edges[:, 0], edges[:, 1], affinities=aff, areas=areas)

result[id_] = fake_edges_ + split_edges_
return result

def copy_fake_edges(self, chunk_id: np.uint64) -> None:
@@ -680,10 +702,10 @@ def get_l2_agglomerations(
if self.mock_edges is None:
edges_d = self.read_chunk_edges(chunk_ids)

fake_edges = self.get_fake_edges(chunk_ids)
edited_edges = self.get_edited_edges(chunk_ids)
all_chunk_edges = reduce(
lambda x, y: x + y,
chain(edges_d.values(), fake_edges.values()),
chain(edges_d.values(), edited_edges.values()),
Edges([], []),
)
if self.mock_edges is not None:
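Why get_edited_edges seeds every concatenation with an empty array: np.concatenate of an empty list raises, so the seed lets chunks without stored fake/split edges fall through to an empty edge list. A small standalone sketch, with empty_2d standing in for pychunkedgraph.graph.types.empty_2d:

```python
import numpy as np

# Stand-in for pychunkedgraph.graph.types.empty_2d.
empty_2d = np.empty((0, 2), dtype=np.uint64)

# Chunk with no stored SplitEdges cells: concatenation still succeeds.
cells = []
edges = np.concatenate([empty_2d, *cells])
print(edges.shape)  # (0, 2), so edges[:, 0] / edges[:, 1] are safe to index

# Chunk with one stored cell holding two edges.
cells = [np.array([[1, 2], [3, 4]], dtype=np.uint64)]
edges = np.concatenate([empty_2d, *cells])
print(edges.shape)  # (2, 2)
```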
54 changes: 51 additions & 3 deletions pychunkedgraph/graph/chunks/utils.py
@@ -167,9 +167,7 @@ def _compute_chunk_id(
z: int,
) -> np.uint64:
s_bits_per_dim = meta.bitmasks[layer]
if not (
x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim
):
if not (x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim):
raise ValueError(
f"Coordinate is out of range \
layer: {layer} bits/dim {s_bits_per_dim}. \
@@ -238,3 +236,53 @@ def get_bounding_children_chunks(
if return_unique:
return np.unique(result, axis=0) if result.size else result
return result


def chunks_overlapping_bbox(bbox_min, bbox_max, chunk_size) -> dict:
"""
Find octree chunks overlapping with a bounding box in 3D
and return a dictionary mapping chunk indices to clipped bounding boxes.
"""
bbox_min = np.asarray(bbox_min, dtype=int)
bbox_max = np.asarray(bbox_max, dtype=int)
chunk_size = np.asarray(chunk_size, dtype=int)

start_idx = np.floor_divide(bbox_min, chunk_size).astype(int)
end_idx = np.floor_divide(bbox_max, chunk_size).astype(int)

ix = np.arange(start_idx[0], end_idx[0] + 1)
iy = np.arange(start_idx[1], end_idx[1] + 1)
iz = np.arange(start_idx[2], end_idx[2] + 1)
grid = np.stack(np.meshgrid(ix, iy, iz, indexing="ij"), axis=-1, dtype=int)
grid = grid.reshape(-1, 3)

chunk_min = grid * chunk_size
chunk_max = chunk_min + chunk_size
clipped_min = np.maximum(chunk_min, bbox_min)
clipped_max = np.minimum(chunk_max, bbox_max)
return {
tuple(idx): np.stack([cmin, cmax], axis=0, dtype=int)
for idx, cmin, cmax in zip(grid, clipped_min, clipped_max)
}


def get_neighbors(coord, inclusive: bool = True, min_coord=None, max_coord=None):
"""
Get all valid coordinates in the 3×3×3 cube around a given chunk,
including the chunk itself (if inclusive=True),
respecting bounding box constraints.
"""
offsets = np.array(np.meshgrid([-1, 0, 1], [-1, 0, 1], [-1, 0, 1])).T.reshape(-1, 3)
if not inclusive:
offsets = offsets[~np.all(offsets == 0, axis=1)]

neighbors = np.array(coord) + offsets
if min_coord is None:
min_coord = (0, 0, 0)
min_coord = np.array(min_coord)
neighbors = neighbors[(neighbors >= min_coord).all(axis=1)]

if max_coord is not None:
max_coord = np.array(max_coord)
neighbors = neighbors[(neighbors <= max_coord).all(axis=1)]
return neighbors
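
A quick usage sketch for the two new helpers (coordinates and chunk size are arbitrary; assumes both functions are importable from pychunkedgraph.graph.chunks.utils):

```python
import numpy as np
from pychunkedgraph.graph.chunks.utils import chunks_overlapping_bbox, get_neighbors

# Chunks overlapping a bounding box, mapped to their clipped extents.
chunks = chunks_overlapping_bbox((10, 10, 10), (300, 140, 70), chunk_size=(128, 128, 64))
for idx, clipped in chunks.items():
    print(idx, clipped[0], clipped[1])  # chunk index, clipped min, clipped max

# 26-neighborhood of a chunk, clipped to hypothetical grid bounds.
print(get_neighbors((1, 1, 0), inclusive=False, max_coord=(3, 3, 3)))
```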
6 changes: 5 additions & 1 deletion pychunkedgraph/graph/client/base.py
@@ -115,13 +115,17 @@ class ClientWithIDGen(SimpleClient):
"""

@abstractmethod
def create_node_ids(self, chunk_id):
def create_node_ids(self, chunk_id, size: int):
"""Generate a range of unique IDs in the chunk."""

@abstractmethod
def create_node_id(self, chunk_id):
"""Generate a unique ID in the chunk."""

@abstractmethod
def set_max_node_id(self, chunk_id, node_id):
"""Gets the current maximum node ID in the chunk."""

@abstractmethod
def get_max_node_id(self, chunk_id):
"""Gets the current maximum node ID in the chunk."""
11 changes: 11 additions & 0 deletions pychunkedgraph/graph/client/bigtable/client.py
@@ -615,6 +615,17 @@ def create_node_id(
"""Generate a unique node ID in the chunk."""
return self.create_node_ids(chunk_id, 1, root_chunk=root_chunk)[0]

def set_max_node_id(
self, chunk_id: np.uint64, node_id: np.uint64
) -> basetypes.NODE_ID:
"""Set max segment ID for a given chunk."""
size = int(np.uint64(chunk_id) ^ np.uint64(node_id))
key = serialize_uint64(chunk_id, counter=True)
column = attributes.Concurrency.Counter
row = self._table.append_row(key)
row.increment_cell_value(column.family_id, column.key, size)
row = row.commit()

def get_max_node_id(
self, chunk_id: basetypes.CHUNK_ID, root_chunk=False
) -> basetypes.NODE_ID:
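The XOR in set_max_node_id relies on pcg's node-ID layout, where the chunk ID occupies the high bits and the per-chunk segment counter the low bits, so chunk_id ^ node_id recovers the counter value. A tiny sketch with made-up numbers:

```python
import numpy as np

# Hypothetical IDs, assuming the chunk ID sits in the high bits.
chunk_id = np.uint64(0xABCD) << np.uint64(32)
segment_id = np.uint64(4096)      # per-chunk counter value
node_id = chunk_id | segment_id   # how pcg composes node IDs

# set_max_node_id extracts the counter with an XOR against the chunk ID.
assert int(np.uint64(chunk_id) ^ np.uint64(node_id)) == 4096
```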