Parallelize intersect_edges with chunks as well.
Huite committed Oct 13, 2024
1 parent 49b46ce commit 11911f7
Showing 2 changed files with 75 additions and 55 deletions.
3 changes: 2 additions & 1 deletion numba_celltree/celltree.py
@@ -330,7 +330,8 @@ def intersect_edges(
         ``np.linalg.norm(intersections[:, 1] - intersections[:, 0], axis=1)``.
         """
         edge_coords = cast_edges(edge_coords)
-        return locate_edges(edge_coords, self.celltree_data)
+        n_chunks = nb.get_num_threads()
+        return locate_edges(edge_coords, self.celltree_data, n_chunks)

     def compute_barycentric_weights(
         self,
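For reference, a minimal usage sketch of the parallelized query. The tiny two-triangle mesh, the CellTree2d(vertices, faces, fill_value) constructor call, and the (edge_index, face_index, intersections) return triple are assumptions about the public API rather than part of this diff; the point is only that intersect_edges now fans its query edges out over nb.get_num_threads() chunks internally.

import numpy as np
from numba_celltree import CellTree2d

# Illustrative mesh: a unit square split into two triangles.
vertices = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
faces = np.array([[0, 1, 2], [0, 2, 3]])
tree = CellTree2d(vertices, faces, fill_value=-1)

# One edge crossing the square; edge_coords has shape (n_edge, 2, 2).
edge_coords = np.array([[[-0.5, 0.25], [1.5, 0.75]]])
edge_index, face_index, intersections = tree.intersect_edges(edge_coords)

# Intersected length per (edge, face) pair, as in the docstring above.
lengths = np.linalg.norm(intersections[:, 1] - intersections[:, 0], axis=1)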
127 changes: 73 additions & 54 deletions numba_celltree/query.py
@@ -143,7 +143,7 @@ def locate_box(
     stack[0] = 0
     size = 1
     count = 0
-    length = len(indices)
+    capacity = len(indices)

     while size > 0:
         node_index, size = pop(stack, size)
@@ -160,7 +160,7 @@
                     # of drawing the re-allocation logic in here makes a
                     # significant runtime difference; seems like numba can
                     # optimize this form better.
-                    if indices_size >= length:
+                    if indices_size >= capacity:
                         return -1, indices_size
                     indices[indices_size, 0] = index
                     indices[indices_size, 1] = bbox_index
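The comment above describes the allocation strategy shared by these kernels: the hot loop writes into a fixed-capacity buffer and signals overflow by returning -1, so that reallocation stays outside the jitted inner loop. A stripped-down sketch of that pattern, with illustrative names that are not part of the package:

import numpy as np

def collect_larger(values, threshold, out, out_size):
    # Fixed-capacity kernel: bail out with (-1, out_size) when the buffer
    # is full, leaving the reallocation logic to the caller.
    capacity = len(out)
    count = 0
    for v in values:
        if v > threshold:
            if out_size >= capacity:
                return -1, out_size
            out[out_size] = v
            out_size += 1
            count += 1
    return count, out_size

def collect_all_larger(values, threshold):
    out = np.empty(4, dtype=values.dtype)
    out_size = 0
    while True:
        count, out_size = collect_larger(values, threshold, out, out_size)
        if count != -1:
            break
        # Buffer too small: start over with twice the capacity.
        out_size = 0
        out = np.empty(2 * len(out), dtype=values.dtype)
    return out[:out_size]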
@@ -242,20 +242,22 @@ def locate_edge(
     tree: CellTreeData,
     indices: IntArray,
     intersections: FloatArray,
-    store_intersection: bool,
+    indices_size: int,
+    index: int,
 ):
     # Check if the entire mesh intersects with the line segment at all
     tree_bbox = as_box(tree.bbox)
     tree_intersects, _, _ = cohen_sutherland_line_box_clip(a, b, tree_bbox)
     if not tree_intersects:
-        return 0
+        return 0, indices_size

     V = to_vector(a, b)
     stack = allocate_stack()
     polygon_work_array = allocate_polygon()
     stack[0] = 0
     size = 1
     count = 0
+    capacity = len(indices)

     while size > 0:
         node_index, size = pop(stack, size)
@@ -273,12 +275,16 @@
                 )
                 face_intersects, c, d = cyrus_beck_line_polygon_clip(a, b, polygon)
                 if face_intersects:
-                    if store_intersection:
-                        indices[count] = bbox_index
-                        intersections[count, 0, 0] = c.x
-                        intersections[count, 0, 1] = c.y
-                        intersections[count, 1, 0] = d.x
-                        intersections[count, 1, 1] = d.y
+                    # If insufficient capacity, exit.
+                    if indices_size >= capacity:
+                        return -1, indices_size
+                    indices[indices_size, 0] = index
+                    indices[indices_size, 1] = bbox_index
+                    intersections[indices_size, 0, 0] = c.x
+                    intersections[indices_size, 0, 1] = c.y
+                    intersections[indices_size, 1, 0] = d.x
+                    intersections[indices_size, 1, 1] = d.y
+                    indices_size += 1
                     count += 1
             continue

@@ -331,55 +337,68 @@ def locate_edge(
         elif right:
             stack, size = push(stack, right_child, size)

-    return count
+    return count, indices_size


-@nb.njit(parallel=PARALLEL, cache=True)
-def locate_edges(
+@nb.njit(cache=True)
+def locate_edges_helper(
     edge_coords: FloatArray,
     tree: CellTreeData,
-):
-    # Numba does not support a concurrent list or bag-like structure:
-    # https://github.com/numba/numba/issues/5878
-    # (Standard lists are not thread safe.)
-    # To support parallel execution, we're stuck with numpy arrays therefore.
-    # Since we don't know the number of contained bounding boxes, we traverse
-    # the tree twice: first to count, then allocate, then another time to
-    # actually store the indices.
-    # The cost of traversing twice is roughly a factor two. Since many
-    # computers can parallelize over more than two threads, counting first --
-    # which enables parallelization -- should still result in a net speed up.
-    n_edge = edge_coords.shape[0]
-    counts = np.empty(n_edge + 1, dtype=IntDType)
-    int_dummy = np.empty((0,), dtype=IntDType)
-    float_dummy = np.empty((0, 0, 0), dtype=FloatDType)
-    counts[0] = 0
-    # First run a count so we can allocate afterwards
-    for i in nb.prange(n_edge):  # pylint: disable=not-an-iterable
-        a = as_point(edge_coords[i, 0])
-        b = as_point(edge_coords[i, 1])
-        counts[i + 1] = locate_edge(a, b, tree, int_dummy, float_dummy, False)
-
-    # Run a cumulative sum
-    total = 0
-    for i in range(1, n_edge + 1):
-        total += counts[i]
-        counts[i] = total
-
-    # Now allocate appropriately
-    ii = np.empty(total, dtype=IntDType)
-    jj = np.empty(total, dtype=IntDType)
-    xy = np.empty((total, 2, 2), dtype=FloatDType)
-    for i in nb.prange(n_edge):  # pylint: disable=not-an-iterable
-        start = counts[i]
-        end = counts[i + 1]
-        ii[start:end] = i
-        indices = jj[start:end]
-        intersections = xy[start:end]
-        a = as_point(edge_coords[i, 0])
-        b = as_point(edge_coords[i, 1])
-        locate_edge(a, b, tree, indices, intersections, True)
+    offset: int,
+) -> IntArray:
+    n_edge = len(edge_coords)
+    # Ensure the initial indices array isn't too small.
+    n = max(n_edge, 256)
+    indices = np.empty((n, 2), dtype=IntDType)
+    xy = np.empty((n, 2, 2), dtype=FloatDType)
+
+    total_count = 0
+    indices_size = 0
+    for edge_index in range(n_edge):
+        a = as_point(edge_coords[edge_index, 0])
+        b = as_point(edge_coords[edge_index, 1])
+
+        while True:
+            count, indices_size = locate_edge(
+                a, b, tree, indices, xy, indices_size, edge_index + offset
+            )
+            if count != -1:
+                break
+            # Not enough capacity: grow capacity, discard partial work, retry.
+            indices_size = total_count
+            indices = grow(indices)
+            xy = grow(xy)
+
+        total_count += count
+
+    return indices, xy, total_count
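The grow helper called above is defined elsewhere in query.py and is not part of this diff. A plausible sketch of its behavior, inferred from the call sites (the doubling factor and the implementation are assumptions; the real helper is presumably numba-jitted):

def grow(array):
    # Assumed behavior: double the capacity along the first axis while
    # keeping the existing entries and dtype.
    n = array.shape[0]
    new = np.empty((2 * n,) + array.shape[1:], dtype=array.dtype)
    new[:n] = array
    return new

Keeping the existing rows matters here: after a failed attempt only the current edge's partial results are discarded (indices_size is reset to total_count), while the results of earlier edges are retained.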


+@nb.njit(cache=True, parallel=PARALLEL)
+def locate_edges(box_coords: FloatArray, tree: CellTreeData, n_chunks: int):
+    chunks = np.array_split(box_coords, n_chunks)
+    offsets = np.zeros(n_chunks, dtype=IntDType)
+    for i, chunk in enumerate(chunks[:-1]):
+        offsets[i + 1] = offsets[i] + len(chunk)
+
+    # Setup (dummy) typed lists for numba to store parallel results.
+    indices = [np.empty((0, 2), dtype=IntDType) for _ in range(n_chunks)]
+    intersections = [np.empty((0, 2, 2), dtype=FloatDType) for _ in range(n_chunks)]
+    counts = np.empty(n_chunks, dtype=IntDType)
+    for i in nb.prange(n_chunks):
+        indices[i], intersections[i], counts[i] = locate_edges_helper(
+            chunks[i], tree, offsets[i]
+        )
+
+    total_size = sum(counts)
+    xy = np.empty((total_size, 2, 2), dtype=FloatDType)
+    start = 0
+    for i, size in enumerate(counts):
+        end = start + size
+        xy[start:end] = intersections[i][:size]
+        start = end
+
+    ii, jj = concatenate_indices(indices, counts)
     return ii, jj, xy
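concatenate_indices is likewise defined elsewhere in query.py and not shown in this diff. Judging purely from the call site, it presumably flattens the valid rows of the per-chunk (edge, face) index pairs into the ii and jj arrays; the sketch below is inferred, not taken from the source:

def concatenate_indices(indices, counts):
    # Inferred sketch: stack the first counts[i] rows of each per-chunk
    # (n, 2) index array; column 0 holds edge indices, column 1 face indices.
    total = 0
    for size in counts:
        total += size
    ii = np.empty(total, dtype=IntDType)
    jj = np.empty(total, dtype=IntDType)
    start = 0
    for chunk_indices, size in zip(indices, counts):
        end = start + size
        ii[start:end] = chunk_indices[:size, 0]
        jj[start:end] = chunk_indices[:size, 1]
        start = end
    return ii, jj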


