support zero-sized tensor for PointwiseDynamic (FlagOpen#335)
iclementine authored Dec 2, 2024
1 parent 980c007 commit 84c192f
Showing 3 changed files with 73 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/flag_gems/utils/pointwise_dynamic.py
@@ -830,6 +830,9 @@ def gen_task_partition(self, code: IndentedBuffer):
         else:
             code.writeline("shape = out0.shape")
             code.writeline("num_tasks = out0.numel()")
+        code.writeline("if num_tasks == 0:")
+        with code.indent():
+            self.gen_return(code)
         max_tile_size = self.config.max_tile_size
         code.writeline(
             f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
@@ -855,6 +858,9 @@ def gen_task_partition_1d(self, code: IndentedBuffer):
         else:
             code.writeline("shape = out0.shape")
             code.writeline("num_tasks = out0.numel()")
+        code.writeline("if num_tasks == 0:")
+        with code.indent():
+            self.gen_return(code)
         max_tile_size = self.config.max_tile_size
         code.writeline(
             f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
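
Both hunks add the same guard to the two task-partitioning code generators: before any tile-size heuristics run, the generated wrapper checks whether the output has any elements and returns early when it does not, so no grid is computed and no kernel is launched for a zero-sized tensor. A rough sketch of the effect on the generated wrapper is below; it is illustrative only, and the real variable names and the exact statement emitted by gen_return come from the surrounding code generator, not from this commit.

# Hypothetical shape of a generated wrapper after this change (not actual emitted code)
def generated_wrapper(in0, out0):
    shape = out0.shape
    num_tasks = out0.numel()
    if num_tasks == 0:
        # zero-sized output: skip tile-size heuristics and the kernel launch
        return out0
    # ... tile-size heuristics, grid computation, and kernel launch follow ...
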
47 changes: 47 additions & 0 deletions tests/test_pointwise_dynamic.py
@@ -704,3 +704,50 @@ def add(x, y):
     y = torch.randn_like(x)
     out = add(x, y)
     torch.testing.assert_close(out, x + y)
+
+
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_zero_sized_task_unary(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=(65536, 65536, 65536),
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")], config=config)
+    @triton.jit
+    def f(x):
+        return x * 2.0
+
+    shape = (0, 10)
+    x = torch.randn(shape, device="cuda")
+    out = f(x)
+    torch.testing.assert_close(out, x * 2.0)
+
+
+@pytest.mark.parametrize("use_1d_tile", [True, False])
+@pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER)
+def test_dynamic_function_zero_sized_task_binary(use_1d_tile, use_block_pointer):
+    config = CodeGenConfig(
+        max_tile_size=1024,
+        max_grid_size=(65536, 65536, 65536),
+        max_num_warps_per_cta=32,
+        prefer_block_pointer=use_block_pointer,
+        prefer_1d_tile=use_1d_tile,
+    )
+
+    @pointwise_dynamic(
+        num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")], config=config
+    )
+    @triton.jit
+    def f(x, y):
+        return x * 2.0 + y
+
+    shape = (0, 10)
+    x = torch.randn(shape, device="cuda")
+    y = torch.randn_like(x)
+    out = f(x, y)
+    torch.testing.assert_close(out, x * 2.0 + y)
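
For context on what these tests exercise: a shape such as (0, 10) is a valid PyTorch tensor with zero elements, and eager-mode arithmetic on it simply returns another empty tensor, which is the reference the decorated functions are compared against. A small standalone illustration (plain PyTorch, independent of flag_gems):

import torch

x = torch.randn((0, 10), device="cuda")
print(x.numel())        # 0
print((x * 2.0).shape)  # torch.Size([0, 10]) -- empty output, nothing to compute
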
20 changes: 20 additions & 0 deletions tests/test_special_ops.py
@@ -708,6 +708,26 @@ def test_accuracy_cat(shape, dim, dtype):
     gems_assert_equal(res_out, ref_out)


+@pytest.mark.cat
+@pytest.mark.parametrize(
+    "shape, dim",
+    [
+        (((0, 3), (2, 3)), 0),
+        (((0, 3), (0, 3)), 0),
+        (((0,), (0,)), 0),
+    ],
+)
+@pytest.mark.parametrize("dtype", [torch.float32])
+def test_accuracy_cat_empty_tensor(shape, dim, dtype):
+    inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
+    ref_inp = [to_reference(_) for _ in inp]
+    ref_out = torch.cat(ref_inp, dim)
+
+    with flag_gems.use_gems():
+        res_out = torch.cat(inp, dim)
+    gems_assert_equal(res_out, ref_out)
+
+
 VSTACK_SHAPES = [
     [(3,), (3,)],
     [(3, 33), (7, 33)],
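
The concatenation cases added above are also well-defined in eager PyTorch: concatenating an empty tensor with a non-empty one along dim 0 yields just the non-empty data, and concatenating only empty tensors yields an empty result. A quick eager-mode illustration of the expected shapes (not part of the diff):

import torch

a = torch.randn((0, 3), device="cuda")
b = torch.randn((2, 3), device="cuda")
print(torch.cat([a, b], dim=0).shape)  # torch.Size([2, 3])
print(torch.cat([a, a], dim=0).shape)  # torch.Size([0, 3])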
