Skip to content

Commit

Permalink
feat: support large axis sizes for statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Mar 28, 2024
1 parent 62ab41e commit 9cfc47f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
17 changes: 17 additions & 0 deletions automated_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,23 @@ def test_statistics(order):
stats = cc3d.statistics(labels)
assert np.all(stats["centroids"][0] == np.array([255.5,255.5,255.5]))

@pytest.mark.parametrize("order", ["C", "F"])
def test_statistics_big(order):
labels = np.zeros((50,66000,1), dtype=np.uint8, order=order)
labels[10:20,10:20,:2] = 1
labels[40:50,40:50,:2] = 2

stats = cc3d.statistics(labels)
assert stats["voxel_counts"][1] == 100
assert stats["voxel_counts"][2] == 10 * 10 * 1

labels = np.zeros((66000,60,1), dtype=np.uint8, order=order)
labels[10:20,10:20,:2] = 1
labels[40:50,40:50,:2] = 2

stats = cc3d.statistics(labels)
assert stats["voxel_counts"][1] == 100
assert stats["voxel_counts"][2] == 10 * 10 * 1

@pytest.mark.parametrize("connectivity", (8, 18, 26))
@pytest.mark.parametrize("dtype", TEST_TYPES)
Expand Down
63 changes: 44 additions & 19 deletions cc3d.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ ctypedef fused INTEGER:
int32_t
int64_t

ctypedef fused BBOX_T:
uint16_t
uint32_t

class DimensionError(Exception):
"""The array has the wrong number of dimensions."""
pass
Expand Down Expand Up @@ -656,16 +660,37 @@ def _statistics(
f"Statistics can only be computed on volumes containing labels with values lower than the number of voxels. Max: {N}"
)
if np.any(np.array([sx, sy, sz]) > np.iinfo(np.uint16).max):
raise ValueError(f"Only dimensions shorter than 65536 are supported. Shape: {sx}, {sy}, {sz}")
cdef cnp.ndarray[uint16_t] bounding_boxes16
cdef cnp.ndarray[uint32_t] bounding_boxes32
if np.any(np.array([sx,sy,sz]) > np.iinfo(np.uint16).max):
bounding_boxes32 = np.zeros(6 * (N + 1), dtype=np.uint32)
return _statistics_helper(out_labels, no_slice_conversion, bounding_boxes32, N)
else:
bounding_boxes16 = np.zeros(6 * (N + 1), dtype=np.uint16)
return _statistics_helper(out_labels, no_slice_conversion, bounding_boxes16, N)
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def _statistics_helper(
cnp.ndarray[UINT, ndim=3] out_labels,
native_bool no_slice_conversion,
cnp.ndarray[BBOX_T, ndim=1] bounding_boxes,
uint64_t N
):
cdef uint64_t voxels = out_labels.size;
cdef uint64_t sx = out_labels.shape[0]
cdef uint64_t sy = out_labels.shape[1]
cdef uint64_t sz = out_labels.shape[2]
cdef cnp.ndarray[uint32_t] counts = np.zeros(N + 1, dtype=np.uint32)
cdef cnp.ndarray[uint16_t] bounding_boxes = np.zeros(6 * (N + 1), dtype=np.uint16)
cdef cnp.ndarray[double] centroids = np.zeros(3 * (N + 1), dtype=np.float64)
cdef uint16_t x = 0
cdef uint16_t y = 0
cdef uint16_t z = 0
cdef BBOX_T x = 0
cdef BBOX_T y = 0
cdef BBOX_T z = 0
cdef uint64_t label = 0
Expand All @@ -677,12 +702,12 @@ def _statistics(
for x in range(sx):
label = <uint64_t>out_labels[x,y,z]
counts[label] += 1
bounding_boxes[6 * label + 0] = <uint16_t>min(bounding_boxes[6 * label + 0], x)
bounding_boxes[6 * label + 1] = <uint16_t>max(bounding_boxes[6 * label + 1], x)
bounding_boxes[6 * label + 2] = <uint16_t>min(bounding_boxes[6 * label + 2], y)
bounding_boxes[6 * label + 3] = <uint16_t>max(bounding_boxes[6 * label + 3], y)
bounding_boxes[6 * label + 4] = <uint16_t>min(bounding_boxes[6 * label + 4], z)
bounding_boxes[6 * label + 5] = <uint16_t>max(bounding_boxes[6 * label + 5], z)
bounding_boxes[6 * label + 0] = <BBOX_T>min(bounding_boxes[6 * label + 0], x)
bounding_boxes[6 * label + 1] = <BBOX_T>max(bounding_boxes[6 * label + 1], x)
bounding_boxes[6 * label + 2] = <BBOX_T>min(bounding_boxes[6 * label + 2], y)
bounding_boxes[6 * label + 3] = <BBOX_T>max(bounding_boxes[6 * label + 3], y)
bounding_boxes[6 * label + 4] = <BBOX_T>min(bounding_boxes[6 * label + 4], z)
bounding_boxes[6 * label + 5] = <BBOX_T>max(bounding_boxes[6 * label + 5], z)
centroids[3 * label + 0] += <double>x
centroids[3 * label + 1] += <double>y
centroids[3 * label + 2] += <double>z
Expand All @@ -692,12 +717,12 @@ def _statistics(
for z in range(sz):
label = <uint64_t>out_labels[x,y,z]
counts[label] += 1
bounding_boxes[6 * label + 0] = <uint16_t>min(bounding_boxes[6 * label + 0], x)
bounding_boxes[6 * label + 1] = <uint16_t>max(bounding_boxes[6 * label + 1], x)
bounding_boxes[6 * label + 2] = <uint16_t>min(bounding_boxes[6 * label + 2], y)
bounding_boxes[6 * label + 3] = <uint16_t>max(bounding_boxes[6 * label + 3], y)
bounding_boxes[6 * label + 4] = <uint16_t>min(bounding_boxes[6 * label + 4], z)
bounding_boxes[6 * label + 5] = <uint16_t>max(bounding_boxes[6 * label + 5], z)
bounding_boxes[6 * label + 0] = <BBOX_T>min(bounding_boxes[6 * label + 0], x)
bounding_boxes[6 * label + 1] = <BBOX_T>max(bounding_boxes[6 * label + 1], x)
bounding_boxes[6 * label + 2] = <BBOX_T>min(bounding_boxes[6 * label + 2], y)
bounding_boxes[6 * label + 3] = <BBOX_T>max(bounding_boxes[6 * label + 3], y)
bounding_boxes[6 * label + 4] = <BBOX_T>min(bounding_boxes[6 * label + 4], z)
bounding_boxes[6 * label + 5] = <BBOX_T>max(bounding_boxes[6 * label + 5], z)
centroids[3 * label + 0] += <double>x
centroids[3 * label + 1] += <double>y
centroids[3 * label + 2] += <double>z
Expand Down Expand Up @@ -729,7 +754,7 @@ def _statistics(
slices.append((slice(xs, int(xe+1)), slice(ys, int(ye+1)), slice(zs, int(ze+1))))
else:
slices.append(None)
output["bounding_boxes"] = slices
return output
Expand Down

0 comments on commit 9cfc47f

Please sign in to comment.