diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index c3e52045..170ad784 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -1733,18 +1733,7 @@ def test(
 @require_e2e
-@pytest.mark.parametrize(
-    "shape",
-    [
-        (1, 27),
-        (1, 64),
-        (51, 64),
-        (128, 64),
-        (1, 256),
-        (1, 512),
-        (64, 500),
-    ],
-)
+@pytest.mark.parametrize("shape", get_test_shapes("test_scanop_cumsum"))
 def test_scanop_cumsum(shape, run_bench):
     M = tkl.sym.M
     N = tkl.sym.N
@@ -1774,7 +1763,6 @@ def test(
         res = tkw.cumsum(lhs, dim=N)
         tkw.write(res, c)
 
-    torch.manual_seed(1)
     input = device_randint(low=1, high=5, size=shape, dtype=torch.int32)
     output = device_zeros(shape, dtype=torch.int32)
     torch_ref = torch.cumsum((input), dim=-1, dtype=torch.int32)
@@ -1794,6 +1782,67 @@ def test(
     assert_close(torch_ref, output, atol=1e-03, rtol=1e-05)
 
 
+@require_e2e
+@pytest.mark.parametrize("shape", [(256, 256)])
+def test_block_scanop_cumsum(shape, run_bench):
+    round_to_divisible = lambda src, denom: sympy.ceiling(src / denom) * denom
+    M = tkl.sym.M
+    N = tkl.sym.N
+    wave_size = 64
+    num_waves = 4
+    BLOCK_M = 1
+
+    # Distribute the N dim across num_waves and pad to be divisible by wave_size.
+    ELEMS_PER_WAVE = round_to_divisible(sympy.ceiling(N / num_waves), wave_size)
+    # Each wave handles at least a full wave's worth of elements.
+    ELEMS_PER_WAVE = sympy.Max(ELEMS_PER_WAVE, wave_size)
+    BLOCK_N = ELEMS_PER_WAVE * num_waves
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=wave_size,
+            vector_shapes={M: 1, N: ELEMS_PER_WAVE},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, ELEMS_PER_WAVE)]
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
+        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32],
+    ):
+        lhs = tkw.read(a)
+        res = tkw.cumsum(lhs, dim=N, block=True)
+        tkw.write(res, c)
+
+    input = device_ones(shape, dtype=torch.int32)
+    output = device_zeros(shape, dtype=torch.int32)
+    torch_ref = torch.cumsum(input, dim=-1, dtype=torch.int32)
+    options = WaveCompileOptions(
+        subs={
+            M: shape[0],
+            N: shape[1],
+            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
+        },
+        canonicalize=True,
+        wave_runtime=True,
+        run_bench=run_bench,
+    )
+    options = set_default_run_config(options)
+    test = wave_compile(options, test)
+
+    test(input, output)
+    assert_close(torch_ref, output, atol=1e-03, rtol=1e-05)
+
+
 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("test_vector_add")[:2])
 @param_bool("use_buffer_ops", "buf_ops")
diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py
index 0d01460f..dd6b6dbc 100644
--- a/wave_lang/kernel/ops/wave_ops.py
+++ b/wave_lang/kernel/ops/wave_ops.py
@@ -2403,11 +2403,13 @@ class ScanOp(CustomOp, ABC):
         arg: Source tensor/value to scan.
         init: Optional initial value.
         dim: Symbolic dimension along which to scan.
+        block: If True, scan across the whole workgroup (block), else only within a wave.
     """
 
     arg: fx.Node | list[fx.Node]
     init: Optional[fx.Node] = None
     dim: Optional[IndexSymbol] = None
+    block: bool = False
 
     @property
     def indexing_dims(self) -> list[IndexSymbol]:
diff --git a/wave_lang/kernel/wave/codegen/handlers.py b/wave_lang/kernel/wave/codegen/handlers.py
index ad58796c..6a6d70ea 100644
--- a/wave_lang/kernel/wave/codegen/handlers.py
+++ b/wave_lang/kernel/wave/codegen/handlers.py
@@ -1556,21 +1556,45 @@ def handle_workgroup_barrier(emitter: WaveEmitter, node: fx.Node):
 @handle_op(extract)
 def handle_extract(emitter: WaveEmitter, node: fx.Node):
     try:
         register, offset = node.args
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
     assert isinstance(offset, list) and len(offset) == 1
     extract_vector = cast_vector(emitter, register)
     result_type = VectorType.get([1], extract_vector.type.element_type)
     # Instead of using `extract_strided_slice` op, we use `extract` + `splat`
     # to construct the result vector, to enable more opportunities for them to
     # be fused with nearby elementwise and memory ops.
-    element = vector_d.extract(
-        extract_vector, static_position=offset, dynamic_position=[]
-    )
+    if isinstance(offset[0], int):
+        # Constant offset: extract at a static position.
+        element = vector_d.extract(
+            extract_vector, static_position=offset, dynamic_position=[]
+        )
+    else:
+        # Symbolic offset (e.g. a wave-id dependent index): materialize it as an
+        # index value and extract at a dynamic position.
+        from iree.compiler import ir
+
+        dynamic_position = [
+            gen_sympy_index(add_emitter_subs(emitter), off) for off in offset
+        ]
+        element = vector_d.extract(
+            extract_vector,
+            static_position=[ir.ShapedType.get_dynamic_size()],
+            dynamic_position=dynamic_position,
+        )
     element = vector_d.broadcast(result_type, element)
 
     emitter.bind_node_proxy(node, IRProxyValue(element))
 
 
 @handle_op(extract_slice)
diff --git a/wave_lang/kernel/wave/decompose_scan_ops.py b/wave_lang/kernel/wave/decompose_scan_ops.py
index bdc24280..5d810dcd 100644
--- a/wave_lang/kernel/wave/decompose_scan_ops.py
+++ b/wave_lang/kernel/wave/decompose_scan_ops.py
@@ -4,27 +4,42 @@
 import torch.fx as fx
 
-from wave_lang.kernel._support.indexing import IndexSymbol
-from wave_lang.kernel.wave.utils.general_utils import all_equal
+from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE
+from wave_lang.kernel._support.indexing import IndexSequence, IndexSymbol
+from wave_lang.kernel.wave.utils.general_utils import all_equal, delinearize_index
 from wave_lang.kernel.wave.utils.symbol_utils import subs_idxc
+import wave_lang.kernel.lang as tkl
+
 from .._support.dtype import i1
 from .._support.tracing import CapturedTrace
 from ..ops.wave_ops import (
     Add,
+    Allocate,
+    Broadcast,
+    Conditional,
     Cumsum,
     CustomOp,
+    Eq,
     Extract,
+    Ge,
+    Gt,
+    Mul,
     NewRegister,
+    NewScalar,
+    Placeholder,
+    Read,
     Reshape,
     ScanOp,
     SelectOp,
     ShuffleOp,
+    Sub,
+    Write,
     get_custom,
 )
-from .constraints import HardwareConstraint
+from .constraints import HardwareConstraint, WaveConstraint, WorkgroupConstraint
 from .utils.classes import ShuffleMode
-from .utils.graph_utils import DCE
+from .utils.graph_utils import DCE, get_outer_node
 
 
 def get_graph_node(custom: CustomOp, graph: fx.Graph) -> fx.Node:
@@ -57,6 +72,39 @@ def emit_local_inclusive_scan(
     return result
 
 
+def emit_local_inclusive_scan_block(
+    binary_fn: Callable,
+    scan_src: list[fx.Node],
+    graph: fx.Graph,
+    elements_per_thread: int,
+) -> list[list[fx.Node]]:
+    """
+    Perform a local scan over the `elements_per_thread` elements held by each
+    thread. The result is shifted by one: element 0 is the identity and element
+    `i` holds the combination of elements `0..i-1`, i.e. an exclusive prefix
+    used to derive per-wave offsets.
+    """
+    result = []
+    for node in scan_src:
+        values = [
+            get_graph_node(Extract(node, [i]), graph)
+            for i in range(elements_per_thread - 1)
+        ]
+        # Prepend the identity element so the scan result is exclusive.
+        values.insert(
+            0,
+            get_graph_node(
+                NewRegister(node.type.symbolic_shape, node.type.dtype, 0), graph
+            ),
+        )
+        values[0].index = node.index
+        values[1].index = node.index
+
+        for i in range(2, elements_per_thread):
+            values[i] = get_graph_node(binary_fn(values[i], values[i - 1]), graph)
+            values[i].index = node.index
+
+        result.append(values)
+
+    return result
+
+
 def emit_variable_scan(
     binary_fn: Callable,
     src: list[list[fx.Node]],
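The helper added above builds, for each thread, a shifted prefix: element 0 is the identity and element `i` holds the combination of elements `0..i-1`. A plain-Python sketch of that construction (illustrative; the real helper emits Extract/NewRegister nodes rather than operating on values):

    def shifted_prefix(elems):
        # [x0, x1, x2, x3] -> [0, x0, x0 + x1, x0 + x1 + x2]
        values = [0] + list(elems[:-1])
        for i in range(2, len(values)):
            values[i] = values[i] + values[i - 1]
        return values

    assert shifted_prefix([64, 64, 64, 64]) == [0, 64, 128, 192]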
@@ -147,6 +195,10 @@ def emit_global_scan(
         # perform binary scan op
         scanop_result = get_graph_node(binary_fn(scanop_result, masked), graph)
 
+    # Propagate indexing metadata from the local scan onto the combined result.
+    custom = get_custom(local_scan[0][0])
+    scanop_result.index = custom.index
+    scanop_result.expanded_dims = custom.expanded_dims
+    scanop_result.vector_shapes = custom.vector_shapes
     final_scanop_result = [scanop_result]
 
     if local_scan_size > 1:
@@ -209,6 +261,173 @@
     return final_scanop_result
 
 
+def emit_apply_wave_offsets(
+    binary_fn: Callable,
+    src: fx.Node,  # Register[(M,)] of dtype; already intra-wave scanned
+    incl: list[fx.Node],  # len == num_waves; per-wave inclusive totals
+    graph: fx.Graph,
+    wave_size: int,  # threads_per_wave (e.g. 64)
+    scan_dim,
+    wave_id: int,
+):
+    """
+    Add the exclusive per-wave offsets derived from `incl` onto an already
+    intra-wave scanned source vector, element by element.
+    """
+    srcc = get_custom(src)
+    dtype = srcc.type.dtype
+    num_waves = len(incl)
+
+    # Shift the per-wave values to build exclusive offsets (wave 0 gets the identity).
+    offsets_vec_excl = [get_graph_node(NewScalar(0, dtype), graph), *incl[1:]]
+    offsets_vec_excl[0].index = offsets_vec_excl[1].index
+
+    offset_values = [
+        get_graph_node(Extract(offset, [0]), graph) for offset in offsets_vec_excl
+    ]
+
+    out = []
+    for i in range(wave_size):
+        # TODO: extracting per element only works while the source index is
+        # thread-invariant; selecting the right per-wave offset and src slice for
+        # the distributed index still needs to be worked out.
+        cur_val = get_graph_node(Extract(src, [i]), graph)
+        cur_offset = offset_values[i // wave_size]
+        offset_val = get_graph_node(binary_fn(cur_offset, cur_val), graph)
+        out.append(offset_val)
+
+    return out
+
+
+def emit_interwave_scan(
+    binary_fn,
+    orig_src,
+    src,
+    graph,
+    trace,
+    scan_dim,
+    num_scan_waves,
+    wg_constraint_map,
+    hardware_constraint,
+):
+    """
+    Steps:
+      1) The last active lane of each wave writes its per-wave total to shared memory.
+      2) The per-wave totals are read back and scanned into exclusive offsets.
+      3) Each wave adds its exclusive offset to its intra-wave scan result.
+    """
+
+    lane_id = (
+        hardware_constraint.linearized_thread_id % hardware_constraint.threads_per_wave
+    )
+    wave_id = delinearize_index(
+        hardware_constraint.linearized_thread_id
+        // hardware_constraint.threads_per_wave,
+        hardware_constraint.waves_per_block,
+    )
+    scan_wg_dim = wg_constraint_map[scan_dim].workgroup_dim
+    scan_wave_id = wave_id[scan_wg_dim]  # e.g. 0..num_scan_waves - 1
+
+    src = get_graph_node(
+        Broadcast(src, target_shape=orig_src.type.symbolic_shape), graph
+    )
+    src.index = orig_src.index
+    src.index.update({scan_dim: IndexSequence(0, 1, 1)})
+
+    src_custom = get_custom(src)
+    total_len = subs_idxc(scan_dim)
+    tpw = hardware_constraint.threads_per_wave
+
+    # Handle a possibly partial last wave: its active length is min(tpw, remaining).
+    W_s = get_graph_node(NewScalar(tpw, tkl.i32), graph)
+    temp_rem_after_start = total_len - (scan_wave_id * tpw)
+    rem_after_start = get_graph_node(NewScalar(temp_rem_after_start, tkl.i32), graph)
+    wave_exceeds_rem = get_graph_node(Gt(W_s, rem_after_start), graph)
+    this_wave_len = get_graph_node(
+        SelectOp(cond=wave_exceeds_rem, if_true=rem_after_start, if_false=W_s), graph
+    )
+
+    # Index of the last active lane in this wave (may be < tpw - 1 for a partial wave).
+    lane_id_s = get_graph_node(NewScalar(lane_id, tkl.i32), graph)
+    prev_lane_idx = get_graph_node(
+        Sub(this_wave_len, get_graph_node(NewScalar(1, tkl.i32), graph)), graph
+    )
+    wave_nonempty = get_graph_node(
+        Ge(this_wave_len, get_graph_node(NewScalar(1, tkl.i32), graph)), graph
+    )  # i1
+    eq_last = get_graph_node(Eq(lane_id_s, prev_lane_idx), graph)
+
+    # logical AND via multiply
+    is_last_lane = get_graph_node(Mul(wave_nonempty, eq_last), graph)
+
+    sums_buf = Allocate(
+        (scan_dim,), (num_scan_waves,), src_custom.type.dtype, SHARED_ADDRESS_SPACE
+    ).add_to_graph(graph)
+
+    exec_on_last = fx.Graph()
+    sub_store = f"store_wave_sum_{src.name}"
+
+    ph_src = get_graph_node(Placeholder.from_fx_node(src), exec_on_last)
+    ph_src.type = src_custom.type
+
+    ph_sums = get_graph_node(
+        Placeholder.from_fx_node(get_custom(sums_buf)), exec_on_last
+    )
+    ph_sums.type = get_custom(sums_buf).type
+    ph_sums.meta["lifted"] = sums_buf
+
+    write_sum = Write(ph_src, ph_sums, 1).add_to_graph(exec_on_last)
+    write_sum.index = {scan_dim: IndexSequence(scan_wave_id, 1, 1)}
+
+    cond_store = get_graph_node(
+        Conditional(
+            is_last_lane,
+            subgraph_name=sub_store,
+            implicit_captures=[get_outer_node(src), sums_buf],
+        ),
+        graph,
+    )
+    exec_on_last.parent_op = cond_store
+    trace.add_subgraph(sub_store, exec_on_last)
+    trace.get_root_graph().subgraphs[sub_store] = exec_on_last
+
+    read_totals = Read(
+        sums_buf,
+        elements_per_thread=num_scan_waves,
+        _write_dependency=[cond_store, write_sum],
+    ).add_to_graph(graph)
+
+    read_totals.index = {scan_dim: IndexSequence(0, 1, 1)}
+
+    incl_nested = emit_local_inclusive_scan_block(
+        binary_fn, [read_totals], graph, num_scan_waves
+    )
+    # TODO: validate the accessed offset for configurations beyond the one
+    # exercised by the block cumsum test.
+    updated_offsets = incl_nested[-1]  # exclusive offsets, e.g. [0, 64, 128, 192]
+    updated_offsets = Reshape(
+        args=updated_offsets,
+        target_vector_shape={scan_dim: num_scan_waves},
+    ).add_to_graph(graph)
+
+    off = get_graph_node(Extract(updated_offsets, [scan_wave_id]), graph)
+    off = get_graph_node(Broadcast(off, src.type.symbolic_shape), graph)
+    off.index = src.index
+
+    final_scalar = get_graph_node(binary_fn(off, src), graph)
+
+    return [final_scalar]
+
+
 def decompose_scan_ops(
     trace: CapturedTrace,
     constraints: list,
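A note on the partial-last-wave handling in `emit_interwave_scan` above: the active length of a wave is min(threads_per_wave, elements remaining after the wave's start), which the helper expresses with a compare plus select. A quick plain-Python check of that formula (illustrative; `wave_len` is a made-up name, not part of the patch):

    def wave_len(total_len, wave_id, tpw=64):
        rem_after_start = total_len - wave_id * tpw
        return rem_after_start if tpw > rem_after_start else tpw  # min(tpw, rem)

    assert [wave_len(200, w) for w in range(4)] == [64, 64, 64, 8]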
@@ -233,6 +452,13 @@ def decompose_scan_ops(
         c for c in constraints if isinstance(c, HardwareConstraint)
     )
 
+    wave_constraint_map = {
+        c.dim: c for c in constraints if isinstance(c, WaveConstraint)
+    }
+    workgroup_constraint_map = {
+        c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint)
+    }
+
     subgroup_size = hardware_constraint.threads_per_wave
 
     for node in scan_nodes:
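The two constraint maps added above feed the block-scan path below: the number of participating waves is the workgroup tile size divided by the wave tile size along the scan dimension. With the test configuration (BLOCK_N = 4 * 64, ELEMS_PER_WAVE = 64, threads_per_wave = 64) that works out as follows (plain-Python check, illustrative only):

    workgroup_tile, wave_tile, threads_per_wave = 4 * 64, 64, 64
    num_scan_waves = workgroup_tile // wave_tile
    assert num_scan_waves == 4
    assert num_scan_waves <= threads_per_wave  # required by the butterfly-shuffle stage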
@@ -241,7 +467,7 @@ def decompose_scan_ops(
             raise NotImplementedError(f"ScanOp '{custom}' not supported")
 
         with custom.graph.inserting_before(custom.fx_node):
-            scan_src, scan_acc, scan_dim = node.args
+            scan_src, scan_acc, scan_dim, block_scan = node.args
             binary_fn = Add
 
             if scan_dim is None:
@@ -296,7 +522,7 @@ def decompose_scan_ops(
                 binary_fn, local_scan, custom.graph, local_scan_sizes[0]
             )
 
-            global_scan = emit_global_scan(
+            final_scan = emit_global_scan(
                 binary_fn,
                 scan_src[0],
                 local_scan,
@@ -307,10 +533,38 @@ def decompose_scan_ops(
                 scan_dim,
             )
 
+            if block_scan:
+                # Compute the number of waves to scan across along the scan dim.
+                num_scan_waves = int(
+                    workgroup_constraint_map[scan_dim].tile_size
+                    // wave_constraint_map[scan_dim].tile_size
+                )
+
+                if num_scan_waves > subgroup_size:
+                    raise NotImplementedError(
+                        "The second-stage butterfly shuffle combines the scan "
+                        "outputs of all waves, so at most threads_per_wave waves "
+                        "can participate in a block scan."
+                    )
+
+                # Scan across waves: each wave stores its total into shared memory,
+                # the totals are scanned into exclusive offsets, and each wave adds
+                # its offset to its local scan result.
+                final_scan = emit_interwave_scan(
+                    binary_fn,
+                    scan_src[0],
+                    final_scan[0],
+                    custom.graph,
+                    trace,
+                    scan_dim,
+                    num_scan_waves,
+                    workgroup_constraint_map,
+                    hardware_constraint,
+                )
+
-            # Update the users based on the global scan `reshape` results.
+            # Update the users based on the final scan `reshape` results.
             for user in custom.users:
                 user.update_arg(
-                    custom.fx_node, global_scan[user.expanded_dims[scan_dim]]
+                    custom.fx_node, final_scan[user.expanded_dims[scan_dim]]
                 )
 
     DCE(trace)
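For reviewers who want to sanity-check the decomposition off-device, the arithmetic performed by the new pass is equivalent to the following torch snippet (a model of the pass under the test's 1-element-per-thread, 64-lane configuration, not the pass itself):

    import torch

    x = torch.randint(1, 5, (256,), dtype=torch.int32)
    wave = 64
    chunks = x.reshape(-1, wave)
    local = chunks.cumsum(dim=1)                     # intra-wave inclusive scan
    totals = local[:, -1]                            # per-wave totals (last lane)
    offsets = totals.cumsum(dim=0) - totals          # exclusive wave offsets
    result = (local + offsets[:, None]).reshape(-1)  # add offsets back
    assert torch.equal(result, x.cumsum(dim=0))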