Skip to content

Commit

Permalink
Fix index put core shpes (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Feb 26, 2025
1 parent 43d3c3f commit 5f1ae2a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 43 deletions.
16 changes: 16 additions & 0 deletions benchmark/core_shapes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,19 @@ AttentionBenchmark:
- [4, 8, 2048, 128]
- [4, 8, 3072, 128]
- [4, 8, 4096, 128]

IndexPutAccFalseBenchmark:
shapes:
- [[268435456,], [[65536,],], [65536,]]
- [[32, 32], [[8,], [2, 8]], [8,]]
- [[1024, 1024], [[4, 64],], [1024,]]
- [[512, 512, 512], [[2, 128], [128,], [128,]], [128,]]
- [[512, 512, 512], [[2, 128],], [512,]]

IndexPutAccTrueBenchmark:
shapes:
- [[268435456,], [[65536,],], [65536,]]
- [[32, 32], [[8,], [8,]], [8,]]
- [[1024, 1024], [[64,], [64,]], [64,]]
- [[512, 512, 512], [[128,], [128,], [128,]], [128,]]
- [[512, 512, 512], [[2, 128], [2, 128], [2, 128]], [2, 128]]
89 changes: 46 additions & 43 deletions benchmark/test_select_and_slice_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,45 +304,46 @@ def inner(shapes, dtype, device):
return inner


@pytest.mark.index_put
def test_index_put_acc_false_perf():
class IndexPutBenchmark(GenericBenchmark):
def set_more_shapes(self):
INDEX_PUT_SHAPE = (
((2**28,), ((2**16,),), (2**16,)),
((32, 32), ((8,), (8,)), (8,)),
((32, 32), ((8,), (2, 8)), (8,)),
((32, 32), ((2, 8),), (32,)),
((1024, 1024), ((64,), (64,)), (64,)),
class IndexPutAccFalseBenchmark(GenericBenchmark):
def set_more_shapes(self):
INDEX_PUT_SHAPE = (
((2**28,), ((2**16,),), (2**16,)),
((32, 32), ((8,), (8,)), (8,)),
((32, 32), ((8,), (2, 8)), (8,)),
((32, 32), ((2, 8),), (32,)),
((1024, 1024), ((64,), (64,)), (64,)),
(
(1024, 1024),
(
(1024, 1024),
(64,),
(
(64,),
(
4,
64,
),
4,
64,
),
(64,),
),
(64,),
),
(
(1024, 1024),
(
(1024, 1024),
(
(
4,
64,
),
4,
64,
),
(1024,),
),
((512, 512, 512), ((128,), (128,), (128,)), (128,)),
((512, 512, 512), ((2, 128), (128,), (128,)), (128,)),
((512, 512, 512), ((2, 128),), (512,)),
)
self.shapes = INDEX_PUT_SHAPE
return None

bench = IndexPutBenchmark(
(1024,),
),
((512, 512, 512), ((128,), (128,), (128,)), (128,)),
((512, 512, 512), ((2, 128), (128,), (128,)), (128,)),
((512, 512, 512), ((2, 128),), (512,)),
)
self.shapes = INDEX_PUT_SHAPE
return None


@pytest.mark.index_put
def test_index_put_acc_false_perf():
bench = IndexPutAccFalseBenchmark(
op_name="index_put",
torch_op=torch.index_put,
input_fn=index_put_input_fn(False),
Expand All @@ -351,20 +352,22 @@ def set_more_shapes(self):
bench.run()


class IndexPutAccTrueBenchmark(GenericBenchmark):
def set_more_shapes(self):
INDEX_PUT_SHAPE = (
((2**28,), ((2**16,),), (2**16,)),
((32, 32), ((8,), (8,)), (8,)),
((1024, 1024), ((64,), (64,)), (64,)),
((512, 512, 512), ((128,), (128,), (128,)), (128,)),
((512, 512, 512), ((2, 128), (2, 128), (2, 128)), (2, 128)),
)
self.shapes = INDEX_PUT_SHAPE
return None


@pytest.mark.index_put
def test_index_put_acc_true_perf():
class IndexPutBenchmark(GenericBenchmark):
def set_more_shapes(self):
INDEX_PUT_SHAPE = (
((2**28,), ((2**16,),), (2**16,)),
((32, 32), ((8,), (8,)), (8,)),
((1024, 1024), ((64,), (64,)), (64,)),
((512, 512, 512), ((128,), (128,), (128,)), (128,)),
)
self.shapes = INDEX_PUT_SHAPE
return None

bench = IndexPutBenchmark(
bench = IndexPutAccTrueBenchmark(
op_name="index_put",
torch_op=torch.index_put,
input_fn=index_put_input_fn(True),
Expand Down

0 comments on commit 5f1ae2a

Please sign in to comment.