tests/kernel/wave/wave_e2e_test.py

@@ -1733,18 +1733,7 @@ def test(


 @require_e2e
-@pytest.mark.parametrize(
-    "shape",
-    [
-        (1, 27),
-        (1, 64),
-        (51, 64),
-        (128, 64),
-        (1, 256),
-        (1, 512),
-        (64, 500),
-    ],
-)
+@pytest.mark.parametrize("shape", get_test_shapes("test_scanop_cumsum"))
 def test_scanop_cumsum(shape, run_bench):
     M = tkl.sym.M
     N = tkl.sym.N
@@ -1774,7 +1763,6 @@ def test(
         res = tkw.cumsum(lhs, dim=N)
         tkw.write(res, c)
 
-    torch.manual_seed(1)
     input = device_randint(low=1, high=5, size=shape, dtype=torch.int32)
     output = device_zeros(shape, dtype=torch.int32)
     torch_ref = torch.cumsum((input), dim=-1, dtype=torch.int32)
@@ -1794,6 +1782,67 @@ def test(
     assert_close(torch_ref, output, atol=1e-03, rtol=1e-05)


+@require_e2e
+@pytest.mark.parametrize("shape", [(256, 256)])
+def test_block_scanop_cumsum(shape, run_bench):
+    round_to_divisible = lambda src, denom: sympy.ceiling(src / denom) * denom
+    M = tkl.sym.M
+    N = tkl.sym.N
+    wave_size = 64
+    num_waves = 4
+    BLOCK_M = 1
+
+    # Distribute the N dim across num_waves, padding to a multiple of wave_size.
+    ELEMS_PER_WAVE = round_to_divisible(sympy.ceiling(N / num_waves), wave_size)
+    # Each wave must handle at least wave_size elements.
+    ELEMS_PER_WAVE = sympy.Max(ELEMS_PER_WAVE, wave_size)
+    BLOCK_N = ELEMS_PER_WAVE * num_waves
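+    # Worked example (illustrative): for shape (256, 256) with num_waves = 4
+    # and wave_size = 64, ceil(256 / 4) = 64 is already a multiple of 64, so
+    # ELEMS_PER_WAVE = 64 and BLOCK_N = 64 * 4 = 256; one workgroup covers a
+    # full row.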
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=wave_size,
+            vector_shapes={M: 1, N: ELEMS_PER_WAVE},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, ELEMS_PER_WAVE)]
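+    # Reading of the constraints (illustrative): each workgroup covers a
+    # BLOCK_M x BLOCK_N tile, with the N tile split across num_waves waves of
+    # ELEMS_PER_WAVE elements each, so the block-level scan below has to
+    # stitch the per-wave partial sums together.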

+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
+        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32],
+    ):
+        lhs = tkw.read(a)
+        res = tkw.cumsum(lhs, dim=N, block=True)
+        tkw.write(res, c)
+
+    input = device_randint(low=1, high=5, size=shape, dtype=torch.int32)
+    output = device_zeros(shape, dtype=torch.int32)
+    torch_ref = torch.cumsum(input, dim=-1, dtype=torch.int32)
+    options = WaveCompileOptions(
+        subs={
+            M: shape[0],
+            N: shape[1],
+            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
+        },
+        canonicalize=True,
+        wave_runtime=True,
+        run_bench=run_bench,
+    )
+    options = set_default_run_config(options)
+    test = wave_compile(options, test)
+
+    test(input, output)
+    assert_close(torch_ref, output, atol=1e-03, rtol=1e-05)


 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("test_vector_add")[:2])
 @param_bool("use_buffer_ops", "buf_ops")
wave_lang/kernel/ops/wave_ops.py

@@ -2403,11 +2406,13 @@ class ScanOp(CustomOp, ABC):
         arg: Source tensor/value to scan.
         init: Optional initial value.
         dim: Symbolic dimension along which to scan.
+        block: If True, scan across the whole block (workgroup); otherwise
+            scan within a single wave (warp).
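+
+        Example (mirrors test_block_scanop_cumsum in the e2e tests):
+            res = tkw.cumsum(lhs, dim=N, block=True)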
"""

     arg: fx.Node | list[fx.Node]
     init: Optional[fx.Node] = None
     dim: Optional[IndexSymbol] = None
+    block: Optional[bool] = False

     @property
     def indexing_dims(self) -> list[IndexSymbol]:
wave_lang/kernel/wave/codegen/handlers.py

@@ -1556,21 +1556,45 @@ def handle_workgroup_barrier(emitter: WaveEmitter, node: fx.Node):
 @handle_op(extract)
 def handle_extract(emitter: WaveEmitter, node: fx.Node):
     try:
         register, offset = node.args
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
     assert isinstance(offset, list) and len(offset) == 1
     extract_vector = cast_vector(emitter, register)
     result_type = VectorType.get([1], extract_vector.type.element_type)
     # Instead of using `extract_strided_slice` op, we use `extract` + `splat`
     # to construct the result vector, to enable more opportunities for them to
     # be fused with nearby elementwise and memory ops.
-    element = vector_d.extract(
-        extract_vector, static_position=offset, dynamic_position=[]
-    )
+    if isinstance(offset[0], int):
+        # Constant offset: extract at a static position.
+        element = vector_d.extract(
+            extract_vector, static_position=offset, dynamic_position=[]
+        )
+    else:
+        # Symbolic offset: materialize it as an SSA index value and pass it
+        # as a dynamic position; the static position holds the dynamic-size
+        # sentinel.
+        from iree.compiler import ir
+
+        dynamic_position = [
+            gen_sympy_index(add_emitter_subs(emitter), x) for x in offset
+        ]
+        element = vector_d.extract(
+            extract_vector,
+            static_position=[ir.ShapedType.get_dynamic_size()],
+            dynamic_position=dynamic_position,
+        )
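+        # Illustrative lowering for the dynamic case, assuming i32 elements
+        # and a 64-wide source vector (shapes vary with the kernel); the
+        # broadcast below produces the second op:
+        #   %e = vector.extract %src[%idx] : i32 from vector<64xi32>
+        #   %r = vector.broadcast %e : i32 to vector<1xi32>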
     element = vector_d.broadcast(result_type, element)
 
     emitter.bind_node_proxy(node, IRProxyValue(element))
 
 
 @handle_op(extract_slice)