From 3581c0ae866e77332e599c6e869064e130c9d159 Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 8 Aug 2025 17:08:25 +0000 Subject: [PATCH 1/4] block scan Signed-off-by: xintin --- tests/kernel/wave/wave_e2e_test.py | 73 ++++++-- wave_lang/kernel/ops/wave_ops.py | 2 + wave_lang/kernel/wave/decompose_scan_ops.py | 188 +++++++++++++++++++- 3 files changed, 244 insertions(+), 19 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index c3e520454..090586b1a 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -1733,18 +1733,7 @@ def test( @require_e2e -@pytest.mark.parametrize( - "shape", - [ - (1, 27), - (1, 64), - (51, 64), - (128, 64), - (1, 256), - (1, 512), - (64, 500), - ], -) +@pytest.mark.parametrize("shape", get_test_shapes("test_scanop_cumsum")) def test_scanop_cumsum(shape, run_bench): M = tkl.sym.M N = tkl.sym.N @@ -1774,7 +1763,6 @@ def test( res = tkw.cumsum(lhs, dim=N) tkw.write(res, c) - torch.manual_seed(1) input = device_randint(low=1, high=5, size=shape, dtype=torch.int32) output = device_zeros(shape, dtype=torch.int32) torch_ref = torch.cumsum((input), dim=-1, dtype=torch.int32) @@ -1794,6 +1782,65 @@ def test( assert_close(torch_ref, output, atol=1e-03, rtol=1e-05) +@require_e2e +@pytest.mark.parametrize("shape", [(256, 256)]) +def test_block_scanop_cumsum(shape, run_bench): + round_to_divisible = lambda src, denom: sympy.ceiling(src / denom) * denom + M = tkl.sym.M + N = tkl.sym.N + wave_size = 64 + num_waves = 4 + BLOCK_M = 1 + + # Distribute N dim across num_waves, and pad to disivible by wave_size. + ELEMS_PER_WAVE = round_to_divisible(sympy.ceiling(N / num_waves), wave_size) + # Minimum number of elems per wave should be size of wave. + ELEMS_PER_WAVE = sympy.Max(ELEMS_PER_WAVE, wave_size) + BLOCK_N = ELEMS_PER_WAVE * num_waves + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + vector_shapes={M: 1, N: ELEMS_PER_WAVE}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, ELEMS_PER_WAVE)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], + ): + lhs = tkw.read(a) + res = tkw.cumsum(lhs, dim=N, block=True) + tkw.write(res, c) + + # input = device_randint(low=1, high=5, size=shape, dtype=torch.int32) + input = device_ones(shape, dtype=torch.int32) + output = device_zeros(shape, dtype=torch.int32) + torch_ref = torch.cumsum((input), dim=-1, dtype=torch.int32) + options = WaveCompileOptions( + subs={ + M: shape[0], + N: shape[1], + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + wave_runtime=True, + run_bench=run_bench, + ) + options = set_default_run_config(options) + test = wave_compile(options, test) + + test(input, output) + breakpoint() + assert_close(torch_ref, output, atol=1e-03, rtol=1e-05) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_vector_add")[:2]) @param_bool("use_buffer_ops", "buf_ops") diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py index 0d01460fe..48ce4b6e2 100644 --- a/wave_lang/kernel/ops/wave_ops.py +++ b/wave_lang/kernel/ops/wave_ops.py @@ -2403,11 +2403,13 @@ class ScanOp(CustomOp, ABC): arg: Source tensor/value to scan. init: Optional initial value. dim: Symbolic dimension along which to scan. + block: When set to true, scan across block, else scan across warp. """ arg: fx.Node | list[fx.Node] init: Optional[fx.Node] = None dim: Optional[IndexSymbol] = None + block: Optional[bool] = False @property def indexing_dims(self) -> list[IndexSymbol]: diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py index bdc242801..35f77545e 100644 --- a/wave_lang/kernel/wave/decompose_scan_ops.py +++ b/wave_lang/kernel/wave/decompose_scan_ops.py @@ -2,29 +2,40 @@ from operator import ge from typing import Callable +import sympy import torch.fx as fx -from wave_lang.kernel._support.indexing import IndexSymbol -from wave_lang.kernel.wave.utils.general_utils import all_equal +from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE +from wave_lang.kernel._support.indexing import IndexSequence, IndexSymbol +from wave_lang.kernel.wave.utils.general_utils import all_equal, delinearize_index from wave_lang.kernel.wave.utils.symbol_utils import subs_idxc +import wave_lang.kernel.lang as tkl + from .._support.dtype import i1 from .._support.tracing import CapturedTrace from ..ops.wave_ops import ( Add, + Allocate, + Conditional, Cumsum, CustomOp, Extract, + Eq, NewRegister, + NewScalar, + Placeholder, + Read, Reshape, ScanOp, SelectOp, ShuffleOp, + Write, get_custom, ) -from .constraints import HardwareConstraint +from .constraints import HardwareConstraint, WaveConstraint, WorkgroupConstraint from .utils.classes import ShuffleMode -from .utils.graph_utils import DCE +from .utils.graph_utils import DCE, get_outer_node def get_graph_node(custom: CustomOp, graph: fx.Graph) -> fx.Node: @@ -209,6 +220,136 @@ def emit_global_scan( return final_scanop_result +def emit_interwave_scan( + binary_fn, + src, + graph, + trace, + scan_dim, + num_scan_waves, + wg_constraint_map, + hardware_constraint, + local_scan_size, +): + """ + Computes prefix sum across waves within a block: + 1. Each wave writes its final scanned value to shared memory. + 2. Wave 0 computes prefix scan of those values. + 3. Each wave reads its prefix offset and adds it to its scanned values. + """ + # Possible lane_id: Mod($T0, 64) == 0 -> 63 + lane_id = ( + hardware_constraint.linearized_thread_id % hardware_constraint.threads_per_wave + ) + # Possible wave_id: [Mod(floor($T0/64), 4), 0, 0] + # [[0,3], 0, 0] + wave_id = delinearize_index( + hardware_constraint.linearized_thread_id + // hardware_constraint.threads_per_wave, + hardware_constraint.waves_per_block, + ) + scan_wg_dim = wg_constraint_map[scan_dim].workgroup_dim + # [0, 1, 2, 3] + scan_wave_id = wave_id[scan_wg_dim] + + # Allocate shared memory for per-wave totals + allocate_node = Allocate( + (scan_dim,), + (num_scan_waves,), + src.type.dtype, + SHARED_ADDRESS_SPACE, + ).add_to_graph(graph) + + # Step 1: Each wave's last lane stores its final scan result + execute_on_lane63_graph = fx.Graph() + subgraph_name = f"store_wave_sum_{src.name}" + + placeholder_src = get_graph_node( + Placeholder.from_fx_node(src), execute_on_lane63_graph + ) + placeholder_src.type = src.type + + placeholder_alloc = get_graph_node( + Placeholder.from_fx_node(get_custom(allocate_node)), execute_on_lane63_graph + ) + placeholder_alloc.type = get_custom(allocate_node).type + placeholder_alloc.meta["lifted"] = allocate_node + + write = Write(placeholder_src, placeholder_alloc, 1).add_to_graph( + execute_on_lane63_graph + ) + write.index = {scan_dim: IndexSequence(scan_wave_id, 1, 1)} + + lane_id_reg = get_graph_node(NewScalar(lane_id, tkl.i32), graph) + max_lane_id = get_graph_node( + NewScalar(hardware_constraint.threads_per_wave - 1, tkl.i32), graph + ) + is_last_lane = get_graph_node(Eq(lane_id_reg, max_lane_id), graph) + + implicit_src = get_outer_node(src) + + conditional_write = get_graph_node( + Conditional( + is_last_lane, + subgraph_name=subgraph_name, + implicit_captures=[implicit_src, allocate_node], + ), + graph, + ) + execute_on_lane63_graph.parent_op = conditional_write + trace.add_subgraph(subgraph_name, execute_on_lane63_graph) + trace.get_root_graph().subgraphs[subgraph_name] = execute_on_lane63_graph + + read_totals = Read( + allocate_node, + elements_per_thread=num_scan_waves, + _write_dependency=[conditional_write, write], + ).add_to_graph(graph) + read_totals.index = {scan_dim: IndexSequence(0, 1, 1)} + + scanned_totals_nested = emit_local_inclusive_scan( + binary_fn, [read_totals], graph, num_scan_waves + ) + scanned_totals = scanned_totals_nested[0] # flatten: list of num_scan_waves scalars + + packed_totals = Reshape( + args=scanned_totals, + target_vector_shape={scan_dim: num_scan_waves}, + ).add_to_graph(graph) + + prev_idx = sympy.Max(scan_wave_id - 1, 0) + wave_offset = Read(packed_totals, elements_per_thread=1).add_to_graph(graph) + wave_offset.index = {scan_dim: IndexSequence(prev_idx, 1, 1)} + + scan_wave_id_node = get_graph_node(NewScalar(scan_wave_id, tkl.i32), graph) + zero_scalar = get_graph_node(NewScalar(0, tkl.i32), graph) + is_wave0 = get_graph_node(Eq(scan_wave_id_node, zero_scalar), graph) + cond_vec = get_graph_node( + NewRegister(get_custom(wave_offset).type.symbolic_shape, i1, is_wave0), + graph, + ) + zero_vec = get_graph_node( + NewRegister(get_custom(wave_offset).type.symbolic_shape, src.type.dtype, 0), + graph, + ) + final_offset = get_graph_node(SelectOp(cond_vec, zero_vec, wave_offset), graph) + + updated_scalars = [ + get_graph_node(binary_fn(elem, final_offset), graph) for elem in scanned_totals + ] + final_scan_result = Reshape( + args=updated_scalars, + target_vector_shape={scan_dim: local_scan_size}, + ).add_to_graph(graph) + + custom = get_custom(src) + final_scan_result.index = custom.index + final_scan_result.expanded_dims = custom.expanded_dims + final_scan_result.vector_shapes = custom.vector_shapes + + return [final_scan_result] + + def decompose_scan_ops( trace: CapturedTrace, constraints: list, @@ -233,6 +374,13 @@ def decompose_scan_ops( c for c in constraints if isinstance(c, HardwareConstraint) ) + wave_constraint_map = { + c.dim: c for c in constraints if isinstance(c, WaveConstraint) + } + workgroup_constraint_map = { + c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint) + } + subgroup_size = hardware_constraint.threads_per_wave for node in scan_nodes: @@ -241,7 +389,7 @@ def decompose_scan_ops( raise NotImplementedError(f"ScanOp '{custom}' not supported") with custom.graph.inserting_before(custom.fx_node): - scan_src, scan_acc, scan_dim = node.args + scan_src, scan_acc, scan_dim, block_scan = node.args binary_fn = Add if scan_dim is None: @@ -307,10 +455,38 @@ def decompose_scan_ops( scan_dim, ) + if block_scan: + # compute num_warps to scan across + num_scan_waves = int( + workgroup_constraint_map[scan_dim].tile_size + // wave_constraint_map[scan_dim].tile_size + ) + + if num_scan_waves > subgroup_size: + raise NotImplementedError( + "The 2nd stage butterfly shuffle reduces the" + "the reduction outputs from all the wave. Hence, can only handle at most " + "threads_per_wave number of warps." + ) + + # Scan and update output between waves, by storing individual wave result into shared memory, + # and then adding it to the rest waves to update it. + final_scan = emit_interwave_scan( + binary_fn, + global_scan[0], + custom.graph, + trace, + scan_dim, + num_scan_waves, + workgroup_constraint_map, + hardware_constraint, + local_scan_sizes[0], + ) + breakpoint() # Update the users based on the global scan `reshape` results. for user in custom.users: user.update_arg( - custom.fx_node, global_scan[user.expanded_dims[scan_dim]] + custom.fx_node, final_scan[user.expanded_dims[scan_dim]] ) DCE(trace) From f4b72d8cca65145e5231abaf649b8369a24acb75 Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 13 Aug 2025 12:48:07 +0000 Subject: [PATCH 2/4] wip Signed-off-by: xintin --- tests/kernel/wave/wave_e2e_test.py | 2 + wave_lang/kernel/wave/decompose_scan_ops.py | 191 ++++++++++++-------- 2 files changed, 114 insertions(+), 79 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 090586b1a..9a1fb4399 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -1832,6 +1832,8 @@ def test( canonicalize=True, wave_runtime=True, run_bench=run_bench, + print_ir_after=["all"], + print_ir_before=["all"], ) options = set_default_run_config(options) test = wave_compile(options, test) diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py index 35f77545e..fddafb31f 100644 --- a/wave_lang/kernel/wave/decompose_scan_ops.py +++ b/wave_lang/kernel/wave/decompose_scan_ops.py @@ -2,7 +2,6 @@ from operator import ge from typing import Callable -import sympy import torch.fx as fx from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE @@ -17,6 +16,7 @@ from ..ops.wave_ops import ( Add, Allocate, + Broadcast, Conditional, Cumsum, CustomOp, @@ -32,6 +32,10 @@ ShuffleOp, Write, get_custom, + Mul, + Sub, + Min, + Ge, ) from .constraints import HardwareConstraint, WaveConstraint, WorkgroupConstraint from .utils.classes import ShuffleMode @@ -220,6 +224,31 @@ def emit_global_scan( return final_scanop_result +def emit_add_with_wave_offset( + binary_fn, + src, + my_excl, + graph: fx.Graph, + scan_wave_id, +): + """ """ + + my_off_scl = get_graph_node(Extract(my_excl, [scan_wave_id]), graph) + m_dim = subs_idxc(src.type.symbolic_shape[0]) + + values = [get_graph_node(Extract(src, [i]), graph) for i in range(m_dim)] + + breakpoint() + for i in range(m_dim): + out = get_graph_node(binary_fn(values[i], my_off_scl), graph) + + src_custom = get_custom(src) + out.index = src_custom.index + out.expanded_dims = src_custom.expanded_dims + out.vector_shapes = src_custom.vector_shapes + return out + + def emit_interwave_scan( binary_fn, src, @@ -229,125 +258,130 @@ def emit_interwave_scan( num_scan_waves, wg_constraint_map, hardware_constraint, - local_scan_size, ): """ - Computes prefix sum across waves within a block: - 1. Each wave writes its final scanned value to shared memory. - 2. Wave 0 computes prefix scan of those values. - 3. Each wave reads its prefix offset and adds it to its scanned values. + Steps: + 1) Last active lane of each wave writes its per-wave total to shared. + 2) Wave 0 inclusive-scans these totals, then converts to exclusive offsets. + 3) Each wave reads its exclusive offset and adds it to its scalar. """ - # Possible lane_id: Mod($T0, 64) == 0 -> 63 + lane_id = ( hardware_constraint.linearized_thread_id % hardware_constraint.threads_per_wave ) - # Possible wave_id: [Mod(floor($T0/64), 4), 0, 0] - # [[0,3], 0, 0] wave_id = delinearize_index( hardware_constraint.linearized_thread_id // hardware_constraint.threads_per_wave, hardware_constraint.waves_per_block, ) scan_wg_dim = wg_constraint_map[scan_dim].workgroup_dim - # [0, 1, 2, 3] - scan_wave_id = wave_id[scan_wg_dim] - - # Allocate shared memory for per-wave totals - allocate_node = Allocate( - (scan_dim,), - (num_scan_waves,), - src.type.dtype, - SHARED_ADDRESS_SPACE, - ).add_to_graph(graph) + scan_wave_id = wave_id[scan_wg_dim] # 0...3 - # Step 1: Each wave's last lane stores its final scan result - execute_on_lane63_graph = fx.Graph() - subgraph_name = f"store_wave_sum_{src.name}" + src_custom = get_custom(src) - placeholder_src = get_graph_node( - Placeholder.from_fx_node(src), execute_on_lane63_graph - ) - placeholder_src.type = src.type + total_len = subs_idxc(scan_dim) # 256 - placeholder_alloc = get_graph_node( - Placeholder.from_fx_node(get_custom(allocate_node)), execute_on_lane63_graph - ) - placeholder_alloc.type = get_custom(allocate_node).type - placeholder_alloc.meta["lifted"] = allocate_node + W = hardware_constraint.threads_per_wave - write = Write(placeholder_src, placeholder_alloc, 1).add_to_graph( - execute_on_lane63_graph + ### for partial last wave + wave_start_s = get_graph_node(NewScalar(scan_wave_id * W, tkl.i32), graph) + total_len_s = get_graph_node(NewScalar(total_len, tkl.i32), graph) + W_s = get_graph_node(NewScalar(W, tkl.i32), graph) + + rem_after_start = get_graph_node(Sub(total_len_s, wave_start_s), graph) + this_wave_len = get_graph_node(Min(rem_after_start, W_s), graph) + + lane_id_s = get_graph_node(NewScalar(lane_id, tkl.i32), graph) + last_lane_idx = get_graph_node( + Sub(this_wave_len, get_graph_node(NewScalar(1, tkl.i32), graph)), graph ) - write.index = {scan_dim: IndexSequence(scan_wave_id, 1, 1)} + gt_zero = get_graph_node( + Ge(this_wave_len, get_graph_node(NewScalar(1, tkl.i32), graph)), graph + ) # i1 + eq_last = get_graph_node(Eq(lane_id_s, last_lane_idx), graph) - lane_id_reg = get_graph_node(NewScalar(lane_id, tkl.i32), graph) - max_lane_id = get_graph_node( - NewScalar(hardware_constraint.threads_per_wave - 1, tkl.i32), graph + # logical AND via multiply + is_last_lane = get_graph_node(Mul(gt_zero, eq_last), graph) + + sums_buf = Allocate( + (scan_dim,), (num_scan_waves,), src_custom.type.dtype, SHARED_ADDRESS_SPACE + ).add_to_graph(graph) + + exec_on_last = fx.Graph() + sub_store = f"store_wave_sum_{src.name}" + + ph_src = get_graph_node(Placeholder.from_fx_node(src), exec_on_last) + ph_src.type = src_custom.type + ph_sums = get_graph_node( + Placeholder.from_fx_node(get_custom(sums_buf)), exec_on_last ) - is_last_lane = get_graph_node(Eq(lane_id_reg, max_lane_id), graph) + ph_sums.type = get_custom(sums_buf).type + ph_sums.meta["lifted"] = sums_buf - implicit_src = get_outer_node(src) + write_sum = Write(ph_src, ph_sums, 1).add_to_graph(exec_on_last) + write_sum.index = {scan_dim: IndexSequence(scan_wave_id, 1, 1)} - conditional_write = get_graph_node( + cond_store = get_graph_node( Conditional( is_last_lane, - subgraph_name=subgraph_name, - implicit_captures=[implicit_src, allocate_node], + subgraph_name=sub_store, + implicit_captures=[get_outer_node(src), sums_buf], ), graph, ) - execute_on_lane63_graph.parent_op = conditional_write - trace.add_subgraph(subgraph_name, execute_on_lane63_graph) - trace.get_root_graph().subgraphs[subgraph_name] = execute_on_lane63_graph + exec_on_last.parent_op = cond_store + trace.add_subgraph(sub_store, exec_on_last) + trace.get_root_graph().subgraphs[sub_store] = exec_on_last read_totals = Read( - allocate_node, + sums_buf, elements_per_thread=num_scan_waves, - _write_dependency=[conditional_write, write], + _write_dependency=[cond_store, write_sum], ).add_to_graph(graph) read_totals.index = {scan_dim: IndexSequence(0, 1, 1)} - scanned_totals_nested = emit_local_inclusive_scan( + incl_nested = emit_local_inclusive_scan( binary_fn, [read_totals], graph, num_scan_waves ) - scanned_totals = scanned_totals_nested[0] # flatten: list of num_scan_waves scalars + incl = incl_nested[0] # list length == num_scan_waves - packed_totals = Reshape( - args=scanned_totals, - target_vector_shape={scan_dim: num_scan_waves}, + incl_vec = Reshape( + args=incl, target_vector_shape={scan_dim: num_scan_waves} ).add_to_graph(graph) + prev_idx = get_graph_node(NewScalar(scan_wave_id - 1, tkl.i32), graph) + my_incl_prev = get_graph_node(Extract(incl_vec, [prev_idx]), graph) - prev_idx = sympy.Max(scan_wave_id - 1, 0) - wave_offset = Read(packed_totals, elements_per_thread=1).add_to_graph(graph) - wave_offset.index = {scan_dim: IndexSequence(prev_idx, 1, 1)} + scan_wave_id_s = get_graph_node(NewScalar(scan_wave_id, tkl.i32), graph) + is_first_wave = get_graph_node( + Eq(scan_wave_id_s, get_graph_node(NewScalar(0, tkl.i32), graph)), graph + ) + is_first_wave_b = get_graph_node( + Broadcast(is_first_wave, target_shape=my_incl_prev.type.symbolic_shape), graph + ) + zero_dtype = get_graph_node(NewScalar(0, src_custom.type.dtype), graph) + zero_dtype_b = get_graph_node( + Broadcast(zero_dtype, target_shape=my_incl_prev.type.symbolic_shape), graph + ) - scan_wave_id_node = get_graph_node(NewScalar(scan_wave_id, tkl.i32), graph) - zero_scalar = get_graph_node(NewScalar(0, tkl.i32), graph) - is_wave0 = get_graph_node(Eq(scan_wave_id_node, zero_scalar), graph) - cond_vec = get_graph_node( - NewRegister(get_custom(wave_offset).type.symbolic_shape, i1, is_wave0), + my_excl = get_graph_node( + SelectOp(cond=is_first_wave_b, if_true=zero_dtype_b, if_false=my_incl_prev), graph, ) - zero_vec = get_graph_node( - NewRegister(get_custom(wave_offset).type.symbolic_shape, src.type.dtype, 0), - graph, + + final_scalar = emit_add_with_wave_offset( + binary_fn=binary_fn, + src=src, + my_excl=my_excl, + graph=graph, + scan_wave_id=scan_wave_id, ) - final_offset = get_graph_node(SelectOp(cond_vec, zero_vec, wave_offset), graph) - - updated_scalars = [ - get_graph_node(binary_fn(elem, final_offset), graph) for elem in scanned_totals - ] - final_scan_result = Reshape( - args=updated_scalars, - target_vector_shape={scan_dim: local_scan_size}, - ).add_to_graph(graph) + breakpoint() - custom = get_custom(src) - final_scan_result.index = custom.index - final_scan_result.expanded_dims = custom.expanded_dims - final_scan_result.vector_shapes = custom.vector_shapes + final_scalar.index = src_custom.index + final_scalar.expanded_dims = src_custom.expanded_dims + final_scalar.vector_shapes = src_custom.vector_shapes - return [final_scan_result] + return [final_scalar] def decompose_scan_ops( @@ -480,9 +514,8 @@ def decompose_scan_ops( num_scan_waves, workgroup_constraint_map, hardware_constraint, - local_scan_sizes[0], ) - breakpoint() + # Update the users based on the global scan `reshape` results. for user in custom.users: user.update_arg( From c099e7f19412854508946bddbb9bb39874b36705 Mon Sep 17 00:00:00 2001 From: xintin Date: Mon, 25 Aug 2025 18:58:32 +0000 Subject: [PATCH 3/4] draft block scan Signed-off-by: xintin --- tests/kernel/wave/wave_e2e_test.py | 4 +- wave_lang/kernel/wave/decompose_scan_ops.py | 203 +++++++++++++------- 2 files changed, 139 insertions(+), 68 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 9a1fb4399..170ad784d 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -1832,8 +1832,8 @@ def test( canonicalize=True, wave_runtime=True, run_bench=run_bench, - print_ir_after=["all"], - print_ir_before=["all"], + # print_ir_after=["all"], + # print_ir_before=["all"], ) options = set_default_run_config(options) test = wave_compile(options, test) diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py index fddafb31f..eebc8516d 100644 --- a/wave_lang/kernel/wave/decompose_scan_ops.py +++ b/wave_lang/kernel/wave/decompose_scan_ops.py @@ -34,8 +34,8 @@ get_custom, Mul, Sub, - Min, Ge, + Gt, ) from .constraints import HardwareConstraint, WaveConstraint, WorkgroupConstraint from .utils.classes import ShuffleMode @@ -72,6 +72,38 @@ def emit_local_inclusive_scan( return result +def emit_local_inclusive_scan_block( + binary_fn: Callable, + scan_src: list[fx.Node], + graph: fx.Graph, + elements_per_thread: int, +) -> list[list[fx.Node]]: + """ + Perform local inclusive scan for `n` elements per thread. + """ + result = [] + for node in scan_src: + values = [ + get_graph_node(Extract(node, [i]), graph) + for i in range(elements_per_thread - 1) + ] + values.insert( + 0, + get_graph_node( + NewRegister(node.type.symbolic_shape, node.type.dtype, 0.0), graph + ), + ) + values[0].index = node.index + values[1].index = node.index + + for i in range(2, elements_per_thread): + values[i] = get_graph_node(binary_fn(values[i], values[i - 1]), graph) + values[i].index = node.index + + result.append(values) + return result + + def emit_variable_scan( binary_fn: Callable, src: list[list[fx.Node]], @@ -162,6 +194,10 @@ def emit_global_scan( # perform binary scan op scanop_result = get_graph_node(binary_fn(scanop_result, masked), graph) + custom = get_custom(local_scan[0][0]) + scanop_result.index = custom.index + scanop_result.expanded_dims = custom.expanded_dims + scanop_result.vector_shapes = custom.vector_shapes final_scanop_result = [scanop_result] if local_scan_size > 1: @@ -224,33 +260,54 @@ def emit_global_scan( return final_scanop_result -def emit_add_with_wave_offset( +def emit_apply_wave_offsets( binary_fn, - src, - my_excl, + src, # fx.Node, Register[(M,)].of(dtype); already intra-wave scanned + incl: list[fx.Node], # len = = num_waves; per-wave inclusive totals graph: fx.Graph, - scan_wave_id, + wave_size: int, # threads_per_wave (e.g., 64) + scan_dim, + wave_id: int, ): """ """ - my_off_scl = get_graph_node(Extract(my_excl, [scan_wave_id]), graph) - m_dim = subs_idxc(src.type.symbolic_shape[0]) + srcc = get_custom(src) + dtype = srcc.type.dtype - values = [get_graph_node(Extract(src, [i]), graph) for i in range(m_dim)] + # Resolve M (vector length) + M_sym = srcc.type.symbolic_shape + M_val = subs_idxc(M_sym) + num_waves = len(incl) - breakpoint() - for i in range(m_dim): - out = get_graph_node(binary_fn(values[i], my_off_scl), graph) + offsets_vec_excl = [get_graph_node(NewScalar(0, dtype), graph), *incl[1:]] + offsets_vec_excl[0].index = offsets_vec_excl[1].index + + # combined_offset = Reshape( + # args=offsets_vec_excl, + # target_vector_shape={scan_dim: num_waves}, + # ).add_to_graph(graph) + + for offset in offsets_vec_excl: + offset_values = [ + get_graph_node(Extract(offset, [0]), graph) for i in range(num_waves) + ] + + out = [] + + for i in range(wave_size): + # we cannot do this + cur_val = get_graph_node(Extract(src, [i]), graph) + # How to get right offset and src slice? + cur_offset = offset_values[i // wave_size] + offset_val = get_graph_node(binary_fn(cur_offset, cur_val), graph) + out.append(offset_val) - src_custom = get_custom(src) - out.index = src_custom.index - out.expanded_dims = src_custom.expanded_dims - out.vector_shapes = src_custom.vector_shapes return out def emit_interwave_scan( binary_fn, + orig_src, src, graph, trace, @@ -275,30 +332,37 @@ def emit_interwave_scan( hardware_constraint.waves_per_block, ) scan_wg_dim = wg_constraint_map[scan_dim].workgroup_dim - scan_wave_id = wave_id[scan_wg_dim] # 0...3 + scan_wave_id = wave_id[scan_wg_dim] # 0...3; [Mod(floor($T0/64), 4), 0, 0] - src_custom = get_custom(src) + src = get_graph_node( + Broadcast(src, target_shape=orig_src.type.symbolic_shape), graph + ) + src.index = orig_src.index + src.index.update({scan_dim: IndexSequence(0, 1, 1)}) + src_custom = get_custom(src) total_len = subs_idxc(scan_dim) # 256 - - W = hardware_constraint.threads_per_wave + tpw = hardware_constraint.threads_per_wave # 64 ### for partial last wave - wave_start_s = get_graph_node(NewScalar(scan_wave_id * W, tkl.i32), graph) - total_len_s = get_graph_node(NewScalar(total_len, tkl.i32), graph) - W_s = get_graph_node(NewScalar(W, tkl.i32), graph) - - rem_after_start = get_graph_node(Sub(total_len_s, wave_start_s), graph) - this_wave_len = get_graph_node(Min(rem_after_start, W_s), graph) + W_s = get_graph_node(NewScalar(tpw, tkl.i32), graph) + temp_rem_after_start = total_len - (scan_wave_id * tpw) + rem_after_start = get_graph_node(NewScalar(temp_rem_after_start, tkl.i32), graph) + max_cond = get_graph_node(Gt(W_s, rem_after_start), graph) + this_wave_len = get_graph_node( + SelectOp(cond=max_cond, if_true=W_s, if_false=rem_after_start), graph + ) + # Get last lane id lane_id_s = get_graph_node(NewScalar(lane_id, tkl.i32), graph) - last_lane_idx = get_graph_node( + # get the last lane idx; including case when it is less than 64 + prev_lane_idx = get_graph_node( Sub(this_wave_len, get_graph_node(NewScalar(1, tkl.i32), graph)), graph ) gt_zero = get_graph_node( Ge(this_wave_len, get_graph_node(NewScalar(1, tkl.i32), graph)), graph ) # i1 - eq_last = get_graph_node(Eq(lane_id_s, last_lane_idx), graph) + eq_last = get_graph_node(Eq(lane_id_s, prev_lane_idx), graph) # logical AND via multiply is_last_lane = get_graph_node(Mul(gt_zero, eq_last), graph) @@ -312,6 +376,8 @@ def emit_interwave_scan( ph_src = get_graph_node(Placeholder.from_fx_node(src), exec_on_last) ph_src.type = src_custom.type + # ph_src.meta["lifted"] = src + ph_sums = get_graph_node( Placeholder.from_fx_node(get_custom(sums_buf)), exec_on_last ) @@ -338,48 +404,52 @@ def emit_interwave_scan( elements_per_thread=num_scan_waves, _write_dependency=[cond_store, write_sum], ).add_to_graph(graph) - read_totals.index = {scan_dim: IndexSequence(0, 1, 1)} - incl_nested = emit_local_inclusive_scan( + read_totals.index = {scan_dim: IndexSequence(scan_wave_id, 1, 1)} + + incl_nested = emit_local_inclusive_scan_block( binary_fn, [read_totals], graph, num_scan_waves ) - incl = incl_nested[0] # list length == num_scan_waves - - incl_vec = Reshape( - args=incl, target_vector_shape={scan_dim: num_scan_waves} + # TODO: this is just to test; fix accessed offset + # updated_offsets = incl_nested[-1][-1] + updated_offsets = incl_nested[-1] + updated_offsets = Reshape( + args=updated_offsets, + target_vector_shape={scan_dim: num_scan_waves}, ).add_to_graph(graph) - prev_idx = get_graph_node(NewScalar(scan_wave_id - 1, tkl.i32), graph) - my_incl_prev = get_graph_node(Extract(incl_vec, [prev_idx]), graph) - - scan_wave_id_s = get_graph_node(NewScalar(scan_wave_id, tkl.i32), graph) - is_first_wave = get_graph_node( - Eq(scan_wave_id_s, get_graph_node(NewScalar(0, tkl.i32), graph)), graph - ) - is_first_wave_b = get_graph_node( - Broadcast(is_first_wave, target_shape=my_incl_prev.type.symbolic_shape), graph - ) - zero_dtype = get_graph_node(NewScalar(0, src_custom.type.dtype), graph) - zero_dtype_b = get_graph_node( - Broadcast(zero_dtype, target_shape=my_incl_prev.type.symbolic_shape), graph - ) - - my_excl = get_graph_node( - SelectOp(cond=is_first_wave_b, if_true=zero_dtype_b, if_false=my_incl_prev), - graph, - ) - - final_scalar = emit_add_with_wave_offset( - binary_fn=binary_fn, - src=src, - my_excl=my_excl, - graph=graph, - scan_wave_id=scan_wave_id, - ) - breakpoint() - final_scalar.index = src_custom.index - final_scalar.expanded_dims = src_custom.expanded_dims - final_scalar.vector_shapes = src_custom.vector_shapes + off = get_graph_node(Extract(updated_offsets, [0]), graph) + off = get_graph_node(Broadcast(off, src.type.symbolic_shape), graph) + off.index = src.index + + # list length == num_scan_waves; incl == [extract_1 (used by wave_id1), add_6(used by wave_id2), add_7, add_8] + # read_totals = [64, 64, 64, 64] + # updated_offsets = [0, 64, 128, 192] + + # offsets_vec_excl = [get_graph_node(NewScalar(0, src.type.dtype), graph), *incl[:-1]] + # Remove redundant zero_reg from the final code. + # zero_reg = get_graph_node( + # NewRegister( + # get_custom(src).type.symbolic_shape, + # get_custom(src).type.dtype, + # 0.0, + # ), + # graph, + # ) + # zero_reg.index = src.index + + # if_scan_wave_id_zero = eq(scan_wave_id, 0) + # cond_node = get_graph_node( + # NewRegister(src.type.symbolic_shape, i1, if_scan_wave_id_zero), + # graph, + # ) + # cond_node.index = src.index + + # masked = get_graph_node( + # SelectOp(cond=cond_node, if_true=zero_reg, if_false=off), graph + # ) + + final_scalar = get_graph_node(binary_fn(off, src), graph) return [final_scalar] @@ -478,7 +548,7 @@ def decompose_scan_ops( binary_fn, local_scan, custom.graph, local_scan_sizes[0] ) - global_scan = emit_global_scan( + final_scan = emit_global_scan( binary_fn, scan_src[0], local_scan, @@ -507,7 +577,8 @@ def decompose_scan_ops( # and then adding it to the rest waves to update it. final_scan = emit_interwave_scan( binary_fn, - global_scan[0], + scan_src[0], + final_scan[0], custom.graph, trace, scan_dim, From a50c7e93639077ba647852d84539847124679945 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 9 Sep 2025 18:52:33 +0000 Subject: [PATCH 4/4] draft Signed-off-by: xintin --- wave_lang/kernel/ops/wave_ops.py | 25 +++++---- wave_lang/kernel/wave/codegen/handlers.py | 61 +++++++++++++++------ wave_lang/kernel/wave/decompose_scan_ops.py | 38 ++----------- 3 files changed, 64 insertions(+), 60 deletions(-) diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py index 48ce4b6e2..dd6b6dbc1 100644 --- a/wave_lang/kernel/ops/wave_ops.py +++ b/wave_lang/kernel/ops/wave_ops.py @@ -447,17 +447,20 @@ def new_function(*args: Any, **kwargs: dict[str, Any]): def get_custom(node: fx.Node) -> "CustomOp": """Get the corresponding CustomOp for a given fx.Node.""" - if not isinstance(node, fx.Node): - raise ValueError(f"Expected an fx.Node but got {type(node)}") - - # If the node was created as a CustomOp it has a corresponding field - if hasattr(node, "tkw_op"): - return node.tkw_op.from_fx_node(node) - if node.op == "placeholder": - return Placeholder.from_fx_node(node) - if node.op == "output": - return Output.from_fx_node(node) - return Unknown.from_fx_node(node) + try: + if not isinstance(node, fx.Node): + raise ValueError(f"Expected an fx.Node but got {type(node)}") + + # If the node was created as a CustomOp it has a corresponding field + if hasattr(node, "tkw_op"): + return node.tkw_op.from_fx_node(node) + if node.op == "placeholder": + return Placeholder.from_fx_node(node) + if node.op == "output": + return Output.from_fx_node(node) + return Unknown.from_fx_node(node) + except: + breakpoint() def has_same_custom_type(lhs_type: Memory, rhs_type: Memory) -> bool: diff --git a/wave_lang/kernel/wave/codegen/handlers.py b/wave_lang/kernel/wave/codegen/handlers.py index ad58796cb..6a6d70ead 100644 --- a/wave_lang/kernel/wave/codegen/handlers.py +++ b/wave_lang/kernel/wave/codegen/handlers.py @@ -1556,21 +1556,45 @@ def handle_workgroup_barrier(emitter: WaveEmitter, node: fx.Node): @handle_op(extract) def handle_extract(emitter: WaveEmitter, node: fx.Node): try: - register, offset = node.args - except ValueError as e: - raise ValidationError("Malformed arguments") from e - assert isinstance(offset, list) and len(offset) == 1 - extract_vector = cast_vector(emitter, register) - result_type = VectorType.get([1], extract_vector.type.element_type) - # Instead of using `extract_strided_slice` op, we use `extract` + `splat` - # to construct the result vector, to enable more opportunities for them to - # be fused with nearby elementwise and memory ops. - element = vector_d.extract( - extract_vector, static_position=offset, dynamic_position=[] - ) - element = vector_d.broadcast(result_type, element) + try: + register, offset = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + assert isinstance(offset, list) and len(offset) == 1 + extract_vector = cast_vector(emitter, register) + result_type = VectorType.get([1], extract_vector.type.element_type) + # Instead of using `extract_strided_slice` op, we use `extract` + `splat` + # to construct the result vector, to enable more opportunities for them to + # be fused with nearby elementwise and memory ops. + element = vector_d.extract( + extract_vector, static_position=offset, dynamic_position=[] + ) + element = vector_d.broadcast(result_type, element) + + emitter.bind_node_proxy(node, IRProxyValue(element)) + except: + try: + register, offset = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + assert isinstance(offset, list) and len(offset) == 1 + extract_vector = cast_vector(emitter, register) + result_type = VectorType.get([1], extract_vector.type.element_type) + # Instead of using `extract_strided_slice` op, we use `extract` + `splat` + # to construct the result vector, to enable more opportunities for them to + # be fused with nearby elementwise and memory ops. + from iree.compiler import ir + dyn = ir.ShapedType.get_dynamic_size() + dynamic_position=[gen_sympy_index(add_emitter_subs(emitter), x) for x in offset] + + element = vector_d.extract( + extract_vector, dynamic_position=dynamic_position, static_position=[dyn,], + ) + breakpoint() - emitter.bind_node_proxy(node, IRProxyValue(element)) + element = vector_d.broadcast(result_type, element) + + emitter.bind_node_proxy(node, IRProxyValue(element)) @handle_op(extract_slice) @@ -1644,9 +1668,12 @@ def handle_broadcast(emitter: WaveEmitter, node: fx.Node): emitter.bind_node_proxy(node, IRProxyValue(vector_src)) return - assert ( - vector_type.shape[0] == 1 - ), f"expected vector_type.shape[0] == 1 but got {vector_type}" + try: + assert ( + vector_type.shape[0] == 1 + ), f"expected vector_type.shape[0] == 1 but got {vector_type}" + except: + breakpoint() # Extract and Splat # If by chance broadcast size matches current size, we can return src. diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py index eebc8516d..5d810dcd6 100644 --- a/wave_lang/kernel/wave/decompose_scan_ops.py +++ b/wave_lang/kernel/wave/decompose_scan_ops.py @@ -90,7 +90,7 @@ def emit_local_inclusive_scan_block( values.insert( 0, get_graph_node( - NewRegister(node.type.symbolic_shape, node.type.dtype, 0.0), graph + NewRegister(node.type.symbolic_shape, node.type.dtype, 0), graph ), ) values[0].index = node.index @@ -100,7 +100,8 @@ def emit_local_inclusive_scan_block( values[i] = get_graph_node(binary_fn(values[i], values[i - 1]), graph) values[i].index = node.index - result.append(values) + result.append(values[:4]) + return result @@ -405,50 +406,23 @@ def emit_interwave_scan( _write_dependency=[cond_store, write_sum], ).add_to_graph(graph) - read_totals.index = {scan_dim: IndexSequence(scan_wave_id, 1, 1)} + read_totals.index = {scan_dim: IndexSequence(0, 1, 1)} # scan_wave_id incl_nested = emit_local_inclusive_scan_block( binary_fn, [read_totals], graph, num_scan_waves ) # TODO: this is just to test; fix accessed offset # updated_offsets = incl_nested[-1][-1] - updated_offsets = incl_nested[-1] + updated_offsets = incl_nested[-1] # [0, 64, 128, 192] updated_offsets = Reshape( args=updated_offsets, target_vector_shape={scan_dim: num_scan_waves}, ).add_to_graph(graph) - off = get_graph_node(Extract(updated_offsets, [0]), graph) + off = get_graph_node(Extract(updated_offsets, [scan_wave_id]), graph) off = get_graph_node(Broadcast(off, src.type.symbolic_shape), graph) off.index = src.index - # list length == num_scan_waves; incl == [extract_1 (used by wave_id1), add_6(used by wave_id2), add_7, add_8] - # read_totals = [64, 64, 64, 64] - # updated_offsets = [0, 64, 128, 192] - - # offsets_vec_excl = [get_graph_node(NewScalar(0, src.type.dtype), graph), *incl[:-1]] - # Remove redundant zero_reg from the final code. - # zero_reg = get_graph_node( - # NewRegister( - # get_custom(src).type.symbolic_shape, - # get_custom(src).type.dtype, - # 0.0, - # ), - # graph, - # ) - # zero_reg.index = src.index - - # if_scan_wave_id_zero = eq(scan_wave_id, 0) - # cond_node = get_graph_node( - # NewRegister(src.type.symbolic_shape, i1, if_scan_wave_id_zero), - # graph, - # ) - # cond_node.index = src.index - - # masked = get_graph_node( - # SelectOp(cond=cond_node, if_true=zero_reg, if_false=off), graph - # ) - final_scalar = get_graph_node(binary_fn(off, src), graph) return [final_scalar]