@willghatch willghatch commented Aug 28, 2025

Motivated by my topk implementation which needs broadcasted selection.

@willghatch willghatch requested a review from raikonenfnu August 28, 2025 23:48
@@ -217,7 +217,7 @@ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]):
             continue
         if isinstance(custom.type, DataType):
             continue
-        assert custom.index, f"Index not set for node {custom.fx_node}: {custom}"
+        assert custom.index != None, f"Index not set for node {custom.fx_node}: {custom}"
Contributor Author

I'm concerned about this change. There is a lit test that's hitting this with a SelectOp that wasn't before. The index is an empty dictionary, which makes sense because in the case of that test it's a scalar. But I worry that this could now be missing something useful that it used to catch.
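
To illustrate the concern: an empty dict is falsy in Python, so a truthiness assert rejects it while a comparison against None accepts it. A minimal sketch, using hypothetical helper names (`old_check`, `new_check`) that are not part of the PR; it spells the comparison as `is not None`, which is the usual Python idiom:

```python
def old_check(index):
    """Mirrors `assert custom.index`: fails for None and for an empty dict."""
    return bool(index)


def new_check(index):
    """Mirrors the None comparison: fails only when the index was never set."""
    return index is not None


# A scalar node may legitimately carry an empty index mapping.
scalar_index = {}
unset_index = None
populated_index = {"M": 0}  # hypothetical dimension name, for illustration only

results = {
    "scalar": (old_check(scalar_index), new_check(scalar_index)),
    "unset": (old_check(unset_index), new_check(unset_index)),
    "populated": (old_check(populated_index), new_check(populated_index)),
}
```

The two checks differ only for the empty mapping, so whatever the old assert caught beyond "index is None" was limited to nodes whose index was set but empty.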

@@ -557,11 +556,12 @@ def handle_generic_atomic(emitter: WaveEmitter, node: fx.Node):


 def get_rank(mlir_type):
-    if not isinstance(mlir_type, ShapedType):
+    if hasattr(mlir_type, "rank"):
Contributor Author

This is another questionable change. The vector types that I'm generating are not caught by the isinstance check. Maybe I should find another type to test against. But it was also mildly frustrating to track the problem down to a whitelisted set of types that didn't include mine, so honestly I think I prefer the duck-typed version.
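
The difference between the two checks can be sketched as follows. Everything here is an assumption for illustration: the `Fake*` classes are stand-ins rather than MLIR types, and returning rank 0 for types without a rank is a guess, since the rest of `get_rank` is not shown in the diff.

```python
class FakeShapedType:
    """Hypothetical stand-in for a type on the isinstance whitelist."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    @property
    def rank(self):
        return len(self.shape)


class FakeVectorLike:
    """Hypothetical type that exposes `rank` but is not a FakeShapedType."""

    rank = 1


class FakeScalarType:
    """Hypothetical type with no notion of rank."""


def get_rank_isinstance(mlir_type):
    # Whitelist version: anything outside the known class is treated as rank 0.
    if not isinstance(mlir_type, FakeShapedType):
        return 0
    return mlir_type.rank


def get_rank_duck(mlir_type):
    # Duck-typed version: any object with a `rank` attribute is honoured.
    if hasattr(mlir_type, "rank"):
        return mlir_type.rank
    return 0
```

With these stand-ins, the whitelist version silently reports rank 0 for `FakeVectorLike`, while the duck-typed version reports 1. Both agree on the whitelisted type and on the scalar.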

@willghatch willghatch force-pushed the willghatch/select-broadcast branch 2 times, most recently from 5551fed to 3960fe7 Compare August 29, 2025 00:00
Comment on lines +141 to +143
# allow shape to be provided as a tuple instead of as individual elements, to work around lack of unpacking in subscripts for Python 3.10
if len(shape) == 1 and isinstance(shape[0], tuple):
    shape = shape[0]
Contributor

Is there a lit test that showcases this use case?

Contributor Author

This is used in the new infer_broadcast_shape. I.e., the type is written as Register[broadcast_shape, dtype], because Python versions before 3.11 don't support Register[*broadcast_shape, dtype].
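
A minimal sketch of how such a subscript can accept the shape either as separate elements or as one tuple. The `Register` class below is a hypothetical stand-in, not the wave_lang class, and returning a plain tuple from `__class_getitem__` is only for illustration:

```python
class Register:
    """Hypothetical stand-in that only demonstrates subscript normalization."""

    def __class_getitem__(cls, key):
        # A subscript with commas arrives as a single tuple: (*shape, dtype).
        if not isinstance(key, tuple):
            key = (key,)
        *shape, dtype = key
        shape = tuple(shape)
        # Accept the shape as one tuple so callers on Python 3.10 can write
        # Register[shape, dtype] where 3.11+ would allow Register[*shape, dtype].
        if len(shape) == 1 and isinstance(shape[0], tuple):
            shape = shape[0]
        return (cls.__name__, shape, dtype)


broadcast_shape = ("M", "N")  # hypothetical symbolic dimension names
spelled_out = Register["M", "N", "f32"]
from_tuple = Register[broadcast_shape, "f32"]
```

Both spellings normalize to the same shape, so code that computes a shape at runtime can pass it through unchanged.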

t_rank = get_rank(t_value.type)
f_rank = get_rank(f_value.type)

max_rank = max(c_rank, t_rank, f_rank)
Contributor

Perhaps it would be better to handle the broadcasting inside the Wave passes rather than in the codegen handlers. That more closely matches what we do for BinaryOps, as seen in https://github.com/iree-org/wave/blob/864a9f8283272bf8f6bd1d0ab775a690c3afde37/wave_lang/kernel/wave/analysis/index_sequence_analysis.py#L949C5-L968
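
As a rough sketch of what a shared broadcast step for binary and select operands could look like: the rule below (pick the operand with the most dimensions, require every other operand's dimensions to be a subset of it) is an assumption made for illustration, not a description of the linked code. The function names are hypothetical and are not the PR's `infer_broadcast_shape`.

```python
def broadcast_target_shape(*shapes):
    """Return the shape all operands would be broadcast to.

    Assumption: dimensions are identified by symbolic name, the target is the
    operand shape with the most dimensions, and every other operand may only
    use dimensions that appear in the target.
    """
    target = max(shapes, key=len)
    for shape in shapes:
        missing = set(shape) - set(target)
        if missing:
            raise ValueError(
                f"cannot broadcast {shape} to {target}: extra dims {sorted(missing)}"
            )
    return tuple(target)


def dims_to_add(shape, target):
    """Dimensions an operand lacks and would need to be broadcast along."""
    return tuple(dim for dim in target if dim not in shape)


# Select has three operands: condition, true value, false value.
cond_shape, true_shape, false_shape = ("M",), ("M", "N"), ()
target = broadcast_target_shape(cond_shape, true_shape, false_shape)
```

Because the helper takes any number of shapes, the same function would serve a two-operand binary op and a three-operand select.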

@raikonenfnu raikonenfnu left a comment

In the PR comment, can you add:

  1. Motivation on why you created this PR
  2. List of changes and the intuition behind them

@willghatch willghatch force-pushed the willghatch/select-broadcast branch from 3960fe7 to bf45667 Compare September 5, 2025 20:54
@willghatch
Contributor Author

@raikonenfnu Thanks for the comments, especially the one about a better place to do broadcasting. Here is... well, it's basically a full rewrite that now does the broadcasting in index_sequence_analysis.py.

Motivated by my topk implementation which needs broadcasted selection.
Adds broadcasting for SelectOp alongside broadcasting for binary ops.

Signed-off-by: William G Hatch <[email protected]>
@willghatch willghatch force-pushed the willghatch/select-broadcast branch from bf45667 to 2cc0683 Compare September 5, 2025 21:01
willghatch added a commit that referenced this pull request Sep 9, 2025
This implements a TopkOp with a similar shape to the ReduceOp.

The algorithm it uses is similar to
https://github.com/sgl-project/sglang/blob/6337d9057c78fc29b31d15bec892bb9013201f6d/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu#L185 -- it does a reduction to find the maximum value and index, then masks the value at that index and repeats.

There are some things cargo-culted from the implementation of ReduceOp.  I don't yet understand the placeholder stuff that is done when creating a conditional region during the rewriting.

This PR depends on #251

Signed-off-by: William G Hatch <[email protected]>
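
The loop described in the commit message above (find the maximum and its index, mask that position, repeat) can be sketched in plain Python. This is an illustrative host-side version under stated assumptions, not the TopkOp implementation: the function name is hypothetical, masking uses negative infinity, and ties resolve to the lowest index.

```python
import math


def topk_by_repeated_argmax(values, k):
    """Return (top_values, top_indices) using k rounds of max-then-mask."""
    if k > len(values):
        raise ValueError("k must not exceed the number of values")
    work = list(values)
    top_values, top_indices = [], []
    for _ in range(k):
        # Reduction step: index of the current maximum (first one on ties).
        best_idx = max(range(len(work)), key=work.__getitem__)
        top_values.append(work[best_idx])
        top_indices.append(best_idx)
        # Masking step: an element-wise select between -inf and the old value,
        # driven by a mask that is true only at the winning index.
        work = [-math.inf if i == best_idx else v for i, v in enumerate(work)]
    return top_values, top_indices


values, indices = topk_by_repeated_argmax([0.1, 0.7, 0.2, 0.9], 2)
```

The masking step is where a select with operands of different shapes (a mask, a scalar fill value, and the data) would plausibly need broadcasting.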
@willghatch willghatch mentioned this pull request Sep 9, 2025
@raikonenfnu raikonenfnu left a comment

Will take a closer look tomorrow, but I really like how we have generalized the broadcast handling, i.e. using the same function for SelectOp and binary ops! :)
