implement broadcasting for SelectOp #251
base: main
Conversation
@@ -217,7 +217,7 @@ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]):
             continue
         if isinstance(custom.type, DataType):
             continue
-        assert custom.index, f"Index not set for node {custom.fx_node}: {custom}"
+        assert custom.index != None, f"Index not set for node {custom.fx_node}: {custom}"
I'm concerned about this change. There is a lit test that's hitting this with a SelectOp that wasn't before. The index is an empty dictionary, which makes sense because in the case of that test it's a scalar. But I worry that this could now be missing something useful that it used to catch.
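The distinction the comment is getting at: an empty `index` dict is falsy, so the old `assert custom.index` rejects it, while the new `assert custom.index != None` accepts it. A minimal standalone illustration (a plain dict here, not wave's actual node objects):

```python
# An empty index dict (the scalar-SelectOp case) is falsy but not None.
index = {}

# The new check passes: the index *was* set, it is just empty.
assert index is not None  # equivalent to `index != None`, but idiomatic

# The old truthiness check would fail on the same value:
is_truthy = bool(index)
print(is_truthy)  # False -- `assert index` would raise here
```

As a side note, Python style guides prefer `is not None` over `!= None`, since `!=` can be overridden by a custom `__ne__`.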
@@ -557,11 +556,12 @@ def handle_generic_atomic(emitter: WaveEmitter, node: fx.Node):

 def get_rank(mlir_type):
-    if not isinstance(mlir_type, ShapedType):
+    if hasattr(mlir_type, "rank"):
This is another questionable change. The vector types that I'm generating are not caught here. Maybe I should find another type to test. But this was also a mildly frustrating problem to track down to having a whitelisted set of values that didn't catch mine, so honestly I think I prefer the duck-type version.
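The duck-typed version probes for the attribute instead of whitelisting types, so any shaped type that exposes `rank` is covered. A sketch only; `FakeVectorType` is a stand-in for illustration, not an actual MLIR binding class:

```python
def get_rank(mlir_type):
    # Duck typing: anything exposing `rank` counts as shaped;
    # everything else is treated as a rank-0 scalar. This avoids
    # maintaining a whitelist that can silently miss new types.
    if hasattr(mlir_type, "rank"):
        return mlir_type.rank
    return 0

class FakeVectorType:
    rank = 2  # stand-in for e.g. an MLIR VectorType

print(get_rank(FakeVectorType()))  # 2
print(get_rank(3.14))              # 0 -- scalar fallback
```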
Force-pushed 5551fed to 3960fe7
# allow shape to be provided as a tuple instead of as individual elements,
# to work around lack of unpacking in subscripts for Python 3.10
if len(shape) == 1 and isinstance(shape[0], tuple):
    shape = shape[0]
Is there a LIT test that showcases this use case?
This is used in the new infer_broadcast_shape, i.e. the type is Register[broadcast_shape, dtype], because old versions of Python still don't support Register[*broadcast_shape, dtype].
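Something like the following is presumably what the subscript workaround enables; `Register` here is a toy stand-in for wave's type, just to show the two equivalent spellings:

```python
# Hypothetical sketch, not wave's actual Register implementation.
class Register:
    def __class_getitem__(cls, shape_and_dtype):
        *shape, dtype = shape_and_dtype
        # Accept the shape as a single tuple, e.g. Register[(M, N), f32],
        # because Register[*dims, f32] is a syntax error before
        # Python 3.11 (PEP 646 subscript unpacking).
        if len(shape) == 1 and isinstance(shape[0], tuple):
            shape = list(shape[0])
        return (tuple(shape), dtype)

# Both spellings normalize to the same parameters:
assert Register[("M", "N"), "f32"] == Register["M", "N", "f32"]
```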
t_rank = get_rank(t_value.type)
f_rank = get_rank(f_value.type)

max_rank = max(c_rank, t_rank, f_rank)
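The ranks feed a max over the condition, true, and false operands, and lower-rank operands are then broadcast up to that rank. A toy sketch mirroring the diff (the `Vec` class and operand values are illustrative, not wave's actual codegen objects):

```python
def get_rank(value):
    # Duck-typed rank query, as in the diff above: scalars have rank 0.
    return getattr(value, "rank", 0)

class Vec:
    def __init__(self, rank):
        self.rank = rank

# Select operands: condition and false value are vectors, true value a scalar.
c_value, t_value, f_value = Vec(1), 0.0, Vec(1)
c_rank = get_rank(c_value)
t_rank = get_rank(t_value)
f_rank = get_rank(f_value)

# Every operand of lower rank gets broadcast up to max_rank.
max_rank = max(c_rank, t_rank, f_rank)
print(max_rank)  # 1
```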
Perhaps it will be better if we handle the broadcasting inside Wave passes as opposed to the codegen handlers. This matches closer with what we are doing with BinaryOps as seen in https://github.com/iree-org/wave/blob/864a9f8283272bf8f6bd1d0ab775a690c3afde37/wave_lang/kernel/wave/analysis/index_sequence_analysis.py#L949C5-L968
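The linked binary-op handling resolves a common target shape during analysis, before codegen. A hedged sketch of that kind of shape inference; `infer_broadcast_shape` here is a stand-in with an assumed right-aligned-suffix rule, not necessarily wave's actual semantics:

```python
def infer_broadcast_shape(*shapes):
    # Assumption: the broadcast shape is the longest operand shape, and
    # every shorter shape must match its trailing dims (align from the
    # right); symbolic dims are represented as strings here.
    target = max(shapes, key=len)
    for shape in shapes:
        suffix = target[len(target) - len(shape):]
        assert tuple(suffix) == tuple(shape), \
            f"{shape} cannot broadcast to {target}"
    return tuple(target)

print(infer_broadcast_shape(("M", "N"), ("N",), ()))  # ('M', 'N')
```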
In the PR comment, can you add:
- Motivation for why you created this PR
- A list of changes and the intuition behind each
Force-pushed 3960fe7 to bf45667
@raikonenfnu Thanks for the comments, and especially about a better place to do broadcasting. Here is... well, it's basically a full rewrite to now do the broadcasting in index_sequence_analysis.py.
Motivated by my topk implementation which needs broadcasted selection. Adds broadcasting for SelectOp alongside broadcasting for binary ops. Signed-off-by: William G Hatch <[email protected]>
Force-pushed bf45667 to 2cc0683
This implements a TopkOp with a similar shape to the ReduceOp. The algorithm it uses is similar to https://github.com/sgl-project/sglang/blob/6337d9057c78fc29b31d15bec892bb9013201f6d/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu#L185 -- it does a reduction to find the maximum value and index, then masks the value at that index and repeats. There are some things cargo-culted from the implementation of ReduceOp. I don't yet understand the placeholder stuff that is done when creating a conditional region during the rewriting. This PR depends on #251 Signed-off-by: William G Hatch <[email protected]>
Will take a closer look tomorrow, but I really like how we have generalized the broadcast handling, i.e. using the same fn for SelectOp and BinaryOp! :)